{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to use custom data and implement custom models and metrics"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    ".. _new-model-tutorial:\n",
    "\n",
    "Building a new model in PyTorch Forecasting is relatively easy. Many things are taken care of automatically:\n",
    "\n",
    "* Training, validation and inference are handled automatically for most models; defining the architecture and hyperparameters is sufficient\n",
    "* Data loading, normalization, rescaling, etc. are provided by the TimeSeriesDataSet\n",
    "* Training progress is logged automatically with multiple metrics, including plots of example predictions\n",
    "* Entries are masked automatically when time series have different lengths\n",
    "\n",
    "However, there are a couple of things to keep in mind if you want to make full use of the package. This tutorial first demonstrates how to implement a simple model and then turns to more complicated implementation scenarios.\n",
    "\n",
    "We will answer questions such as:\n",
    "\n",
    "* How to transfer an existing PyTorch implementation into PyTorch Forecasting\n",
    "* How to handle data loading and enable different length time series\n",
    "* How to define and use a custom metric\n",
    "* How to handle recurrent networks\n",
    "* How to deal with covariates\n",
    "* How to test new models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building a simple, first model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For demonstration purposes we will choose a simple fully connected model. It takes a time series of length `input_size` as input and outputs a new time series of length `output_size`. You can think of `input_size` as the number of encoding steps and `output_size` as the number of decoding/prediction steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "os.chdir(\"../../..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([20, 2])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "class FullyConnectedModule(nn.Module):\n",
    "    def __init__(self, input_size: int, output_size: int, hidden_size: int, n_hidden_layers: int):\n",
    "        super().__init__()\n",
    "\n",
    "        # input layer\n",
    "        module_list = [nn.Linear(input_size, hidden_size), nn.ReLU()]\n",
    "        # hidden layers\n",
    "        for _ in range(n_hidden_layers):\n",
    "            module_list.extend([nn.Linear(hidden_size, hidden_size), nn.ReLU()])\n",
    "        # output layer\n",
    "        module_list.append(nn.Linear(hidden_size, output_size))\n",
    "\n",
    "        self.sequential = nn.Sequential(*module_list)\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        # x of shape: batch_size x n_timesteps_in\n",
    "        # output of shape batch_size x n_timesteps_out\n",
    "        return self.sequential(x)\n",
    "\n",
    "\n",
    "# test that network works as intended\n",
    "network = FullyConnectedModule(input_size=5, output_size=2, hidden_size=10, n_hidden_layers=2)\n",
    "x = torch.rand(20, 5)\n",
    "network(x).shape"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "The above model is not yet a PyTorch Forecasting model but it is easy to get there. As this is a simple model, we will use the :py:class:`~pytorch_forecasting.models.base_model.BaseModel`. This base class is a modified `LightningModule <https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html>`_ with pre-defined hooks for training and validating time series models. The :py:class:`~pytorch_forecasting.models.base_model.BaseModelWithCovariates` will be discussed later in this tutorial.\n",
    "\n",
    "Either way, the main requirement is for the model to have a ``forward`` method.\n",
    "\n",
    ".. automethod:: pytorch_forecasting.models.base_model.BaseModel.forward\n",
    "    :noindex:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict\n",
    "\n",
    "from pytorch_forecasting.models import BaseModel\n",
    "\n",
    "\n",
    "class FullyConnectedModel(BaseModel):\n",
    "    def __init__(self, input_size: int, output_size: int, hidden_size: int, n_hidden_layers: int, **kwargs):\n",
    "        # saves arguments in signature to `.hparams` attribute, mandatory call - do not skip this\n",
    "        self.save_hyperparameters()\n",
    "        # pass additional arguments to BaseModel.__init__, mandatory call - do not skip this\n",
    "        super().__init__(**kwargs)\n",
    "        self.network = FullyConnectedModule(\n",
    "            input_size=self.hparams.input_size,\n",
    "            output_size=self.hparams.output_size,\n",
    "            hidden_size=self.hparams.hidden_size,\n",
    "            n_hidden_layers=self.hparams.n_hidden_layers,\n",
    "        )\n",
    "\n",
    "    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n",
    "        # x is a batch generated based on the TimeSeriesDataSet\n",
    "        network_input = x[\"encoder_cont\"].squeeze(-1)\n",
    "        prediction = self.network(network_input)\n",
    "\n",
    "        # rescale predictions into target space\n",
    "        prediction = self.transform_output(prediction, target_scale=x[\"target_scale\"])\n",
    "\n",
    "        # The network output must contain at least the prediction.\n",
    "        # `to_network_output` wraps the keyword arguments into a named tuple.\n",
    "        return self.to_network_output(prediction=prediction)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a very basic implementation that could be readily used for training. But before we add additional features, let's first look at how data is passed to this model, and only then initialize it."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Passing data to a model"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    ".. _passing-data:\n",
    "\n",
    "Instead of having to write our own dataloader (which can be rather complicated), we can leverage PyTorch Forecasting's :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet` to feed data to our model.\n",
    "In fact, PyTorch Forecasting expects us to use a :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`.\n",
    "\n",
    "The data has to be in a specific format to be used by the :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. It should be a pandas `DataFrame` with a categorical column identifying each series and an integer column specifying the time of each record.\n",
    "\n",
    "Below, we create such a dataset with 30 observations: 10 for each of 3 time series."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>value</th>\n",
       "      <th>group</th>\n",
       "      <th>time_idx</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.201798</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.389338</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>-0.285848</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>-0.445960</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.474844</td>\n",
       "      <td>0</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0.012732</td>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>-0.049512</td>\n",
       "      <td>0</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>-0.252430</td>\n",
       "      <td>0</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>-0.274553</td>\n",
       "      <td>0</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>0.436116</td>\n",
       "      <td>0</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>-0.044911</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>-0.127036</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>-0.495227</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>0.427368</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>-0.343048</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>-0.388719</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>0.414308</td>\n",
       "      <td>1</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>-0.194935</td>\n",
       "      <td>1</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>-0.303643</td>\n",
       "      <td>1</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>0.413385</td>\n",
       "      <td>1</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>-0.499601</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>-0.345062</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>0.382416</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>-0.094727</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>-0.243780</td>\n",
       "      <td>2</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>-0.457586</td>\n",
       "      <td>2</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>0.205900</td>\n",
       "      <td>2</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>0.448471</td>\n",
       "      <td>2</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>0.247036</td>\n",
       "      <td>2</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>-0.219905</td>\n",
       "      <td>2</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       value  group  time_idx\n",
       "0   0.201798      0         0\n",
       "1   0.389338      0         1\n",
       "2  -0.285848      0         2\n",
       "3  -0.445960      0         3\n",
       "4   0.474844      0         4\n",
       "5   0.012732      0         5\n",
       "6  -0.049512      0         6\n",
       "7  -0.252430      0         7\n",
       "8  -0.274553      0         8\n",
       "9   0.436116      0         9\n",
       "10 -0.044911      1         0\n",
       "11 -0.127036      1         1\n",
       "12 -0.495227      1         2\n",
       "13  0.427368      1         3\n",
       "14 -0.343048      1         4\n",
       "15 -0.388719      1         5\n",
       "16  0.414308      1         6\n",
       "17 -0.194935      1         7\n",
       "18 -0.303643      1         8\n",
       "19  0.413385      1         9\n",
       "20 -0.499601      2         0\n",
       "21 -0.345062      2         1\n",
       "22  0.382416      2         2\n",
       "23 -0.094727      2         3\n",
       "24 -0.243780      2         4\n",
       "25 -0.457586      2         5\n",
       "26  0.205900      2         6\n",
       "27  0.448471      2         7\n",
       "28  0.247036      2         8\n",
       "29 -0.219905      2         9"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "test_data = pd.DataFrame(\n",
    "    dict(\n",
    "        value=np.random.rand(30) - 0.5,\n",
    "        group=np.repeat(np.arange(3), 10),\n",
    "        time_idx=np.tile(np.arange(10), 3),\n",
    "    )\n",
    ")\n",
    "test_data"
   ]
  },
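  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The format requirements above can be checked before a dataset is built. The sketch below uses a hypothetical helper, `check_long_format` (not part of PyTorch Forecasting), that asserts an integer time column and at most one record per series and time step, and reports whether each series advances in steps of exactly one. The gap check reflects our understanding that `TimeSeriesDataSet` expects consecutive time steps unless `allow_missing_timesteps=True` is passed; treat it as an assumption, not a re-implementation of the library's own validation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def check_long_format(df: pd.DataFrame, group_col: str, time_col: str) -> bool:\n",
    "    \"\"\"Hypothetical helper: sanity-check a frame before passing it to TimeSeriesDataSet.\"\"\"\n",
    "    # the time column must hold integers\n",
    "    assert pd.api.types.is_integer_dtype(df[time_col]), f\"{time_col} must be an integer column\"\n",
    "    # each series may have at most one record per time step\n",
    "    assert not df.duplicated([group_col, time_col]).any(), \"duplicate (group, time) records\"\n",
    "    # True only if every series advances in steps of exactly 1 (no gaps)\n",
    "    steps = df.sort_values(time_col).groupby(group_col)[time_col].diff().dropna()\n",
    "    return bool((steps == 1).all())\n",
    "\n",
    "\n",
    "# a frame with the same layout as `test_data`: 3 series with 10 consecutive steps each\n",
    "toy_frame = pd.DataFrame(\n",
    "    dict(\n",
    "        value=np.random.rand(30) - 0.5,\n",
    "        group=np.repeat(np.arange(3), 10),\n",
    "        time_idx=np.tile(np.arange(10), 3),\n",
    "    )\n",
    ")\n",
    "print(check_long_format(toy_frame, \"group\", \"time_idx\"))  # True\n",
    "# dropping the record at time_idx 3 of series 0 introduces a gap\n",
    "print(check_long_format(toy_frame.drop(index=3), \"group\", \"time_idx\"))  # False"
   ]
  },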
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "Converting it to a :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet` is easy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_forecasting import TimeSeriesDataSet\n",
    "\n",
    "# create the dataset from the pandas dataframe\n",
    "dataset = TimeSeriesDataSet(\n",
    "    test_data,\n",
    "    group_ids=[\"group\"],\n",
    "    target=\"value\",\n",
    "    time_idx=\"time_idx\",\n",
    "    min_encoder_length=5,\n",
    "    max_encoder_length=5,\n",
    "    min_prediction_length=2,\n",
    "    max_prediction_length=2,\n",
    "    time_varying_unknown_reals=[\"value\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "We can take a look at all the defaults and settings that were set by PyTorch Forecasting. These are all available as arguments to :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`; see its documentation for all the details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'time_idx': 'time_idx',\n",
       " 'target': 'value',\n",
       " 'group_ids': ['group'],\n",
       " 'weight': None,\n",
       " 'max_encoder_length': 5,\n",
       " 'min_encoder_length': 5,\n",
       " 'min_prediction_idx': 0,\n",
       " 'min_prediction_length': 2,\n",
       " 'max_prediction_length': 2,\n",
       " 'static_categoricals': [],\n",
       " 'static_reals': [],\n",
       " 'time_varying_known_categoricals': [],\n",
       " 'time_varying_known_reals': [],\n",
       " 'time_varying_unknown_categoricals': [],\n",
       " 'time_varying_unknown_reals': ['value'],\n",
       " 'variable_groups': {},\n",
       " 'constant_fill_strategy': {},\n",
       " 'allow_missing_timesteps': False,\n",
       " 'lags': {},\n",
       " 'add_relative_time_idx': False,\n",
       " 'add_target_scales': False,\n",
       " 'add_encoder_length': False,\n",
       " 'target_normalizer': GroupNormalizer(center=True, eps=1e-08, groups=[], method='standard',\n",
       "                 scale_by_group=False, transformation=None),\n",
       " 'categorical_encoders': {'__group_id__group': NaNLabelEncoder(add_nan=False, warn=True),\n",
       "  'group': NaNLabelEncoder(add_nan=False, warn=True)},\n",
       " 'scalers': {},\n",
       " 'randomize_length': None,\n",
       " 'predict_mode': False}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset.get_parameters()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we take a look at the output of the dataloader. Its `x` will be fed to the model's forward method, which is why it is so important to understand it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x = {'encoder_cat': tensor([], size=(4, 5, 0), dtype=torch.int64), 'encoder_cont': tensor([[[-0.2800],\n",
      "         [-1.3852],\n",
      "         [ 1.3842],\n",
      "         [-0.9284],\n",
      "         [-1.0655]],\n",
      "\n",
      "        [[-0.9345],\n",
      "         [ 1.2493],\n",
      "         [-0.1830],\n",
      "         [-0.6304],\n",
      "         [-1.2723]],\n",
      "\n",
      "        [[-0.7567],\n",
      "         [-1.2374],\n",
      "         [ 1.5267],\n",
      "         [ 0.1396],\n",
      "         [-0.0473]],\n",
      "\n",
      "        [[ 1.2701],\n",
      "         [-0.7567],\n",
      "         [-1.2374],\n",
      "         [ 1.5267],\n",
      "         [ 0.1396]]]), 'encoder_target': tensor([[-0.1270, -0.4952,  0.4274, -0.3430, -0.3887],\n",
      "        [-0.3451,  0.3824, -0.0947, -0.2438, -0.4576],\n",
      "        [-0.2858, -0.4460,  0.4748,  0.0127, -0.0495],\n",
      "        [ 0.3893, -0.2858, -0.4460,  0.4748,  0.0127]]), 'encoder_lengths': tensor([5, 5, 5, 5]), 'decoder_cat': tensor([], size=(4, 2, 0), dtype=torch.int64), 'decoder_cont': tensor([[[ 1.3450],\n",
      "         [-0.4838]],\n",
      "\n",
      "        [[ 0.7194],\n",
      "         [ 1.4476]],\n",
      "\n",
      "        [[-0.6564],\n",
      "         [-0.7228]],\n",
      "\n",
      "        [[-0.0473],\n",
      "         [-0.6564]]]), 'decoder_target': tensor([[ 0.4143, -0.1949],\n",
      "        [ 0.2059,  0.4485],\n",
      "        [-0.2524, -0.2746],\n",
      "        [-0.0495, -0.2524]]), 'decoder_lengths': tensor([2, 2, 2, 2]), 'decoder_time_idx': tensor([[6, 7],\n",
      "        [6, 7],\n",
      "        [7, 8],\n",
      "        [6, 7]]), 'groups': tensor([[1],\n",
      "        [2],\n",
      "        [0],\n",
      "        [0]]), 'target_scale': tensor([[-0.0338,  0.3331],\n",
      "        [-0.0338,  0.3331],\n",
      "        [-0.0338,  0.3331],\n",
      "        [-0.0338,  0.3331]])}\n",
      "\n",
      "y = (tensor([[ 0.4143, -0.1949],\n",
      "        [ 0.2059,  0.4485],\n",
      "        [-0.2524, -0.2746],\n",
      "        [-0.0495, -0.2524]]), None)\n",
      "\n",
      "sizes of x =\n",
      "\tencoder_cat = torch.Size([4, 5, 0])\n",
      "\tencoder_cont = torch.Size([4, 5, 1])\n",
      "\tencoder_target = torch.Size([4, 5])\n",
      "\tencoder_lengths = torch.Size([4])\n",
      "\tdecoder_cat = torch.Size([4, 2, 0])\n",
      "\tdecoder_cont = torch.Size([4, 2, 1])\n",
      "\tdecoder_target = torch.Size([4, 2])\n",
      "\tdecoder_lengths = torch.Size([4])\n",
      "\tdecoder_time_idx = torch.Size([4, 2])\n",
      "\tgroups = torch.Size([4, 1])\n",
      "\ttarget_scale = torch.Size([4, 2])\n"
     ]
    }
   ],
   "source": [
    "# convert the dataset to a dataloader\n",
    "dataloader = dataset.to_dataloader(batch_size=4)\n",
    "\n",
    "# and load the first batch\n",
    "x, y = next(iter(dataloader))\n",
    "print(\"x =\", x)\n",
    "print(\"\\ny =\", y)\n",
    "print(\"\\nsizes of x =\")\n",
    "for key, value in x.items():\n",
    "    print(f\"\\t{key} = {value.size()}\")"
   ]
  },
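  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The values in this batch are consistent with the target normalizer listed by `get_parameters()` above, `GroupNormalizer(method='standard', groups=[])`: each row of `target_scale` holds a center and a scale, and `encoder_cont` and `decoder_cont` match `(target - center) / scale`. The sketch below checks this on the first sample, using values copied from the printed output (rounded to four decimals, hence the tolerances). The last check shows the inverse, `normalized * scale + center`, which we assume is what `transform_output` applies for this normalizer; it illustrates the arithmetic only and is not the library's code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# values copied from the first sample of the printed batch\n",
    "center, scale = -0.0338, 0.3331  # first row of target_scale\n",
    "printed_encoder_target = torch.tensor([-0.1270, -0.4952, 0.4274, -0.3430, -0.3887])\n",
    "printed_encoder_cont = torch.tensor([-0.2800, -1.3852, 1.3842, -0.9284, -1.0655])\n",
    "printed_decoder_target = torch.tensor([0.4143, -0.1949])\n",
    "printed_decoder_cont = torch.tensor([1.3450, -0.4838])\n",
    "\n",
    "# standard scaling: subtract the center, divide by the scale\n",
    "normalized_encoder = (printed_encoder_target - center) / scale\n",
    "normalized_decoder = (printed_decoder_target - center) / scale\n",
    "print(torch.allclose(normalized_encoder, printed_encoder_cont, atol=2e-3))  # True\n",
    "print(torch.allclose(normalized_decoder, printed_decoder_cont, atol=2e-3))  # True\n",
    "\n",
    "# the inverse mapping takes normalized values back into target space\n",
    "print(torch.allclose(printed_decoder_cont * scale + center, printed_decoder_target, atol=1e-3))  # True"
   ]
  },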
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "To understand it better, we look at the documentation of the :py:meth:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet.to_dataloader` method:\n",
    "\n",
    ".. automethod:: pytorch_forecasting.data.timeseries.TimeSeriesDataSet.to_dataloader\n",
    "    :noindex:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This explains why we had to first extract the correct input in our simple `FullyConnectedModel` above before passing it to our `FullyConnectedModule`.\n",
    "As a reminder:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n",
    "    # x is a batch generated based on the TimeSeriesDataSet\n",
    "    network_input = x[\"encoder_cont\"].squeeze(-1)\n",
    "    prediction = self.network(network_input)\n",
    "\n",
    "    # rescale predictions into target space\n",
    "    prediction = self.transform_output(prediction, target_scale=x[\"target_scale\"])\n",
    "\n",
    "    # The network output must contain at least the prediction.\n",
    "    # `to_network_output` wraps the keyword arguments into a named tuple.\n",
    "    return self.to_network_output(prediction=prediction)"
   ]
  },
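  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shape bookkeeping in this `forward` method can be checked in isolation. The sketch below feeds a random tensor shaped like `encoder_cont` in the batch above (4 samples, 5 encoder steps, 1 continuous variable) through `toy_network`, an illustrative stand-in with the same layer sizes as `FullyConnectedModule(input_size=5, output_size=2, hidden_size=10, n_hidden_layers=2)`. Note that `squeeze(-1)` can drop the last dimension only because there is exactly one continuous input, the target itself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# stand-in with the same layer sizes as FullyConnectedModule(5, 2, hidden_size=10, n_hidden_layers=2)\n",
    "toy_network = nn.Sequential(\n",
    "    nn.Linear(5, 10),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(10, 10),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(10, 10),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(10, 2),\n",
    ")\n",
    "\n",
    "toy_encoder_cont = torch.rand(4, 5, 1)  # batch_size x n_encoder_steps x n_continuous_variables\n",
    "toy_network_input = toy_encoder_cont.squeeze(-1)  # batch_size x n_encoder_steps\n",
    "toy_prediction = toy_network(toy_network_input)  # batch_size x n_decoder_steps\n",
    "print(toy_network_input.shape, toy_prediction.shape)  # torch.Size([4, 5]) torch.Size([4, 2])"
   ]
  },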
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For such a simple architecture, we can ignore most of the inputs in ``x``. You do not have to worry about moving tensors to specific GPUs; [PyTorch Lightning](https://pytorch-lightning.readthedocs.io) will take care of this for you.\n",
    "\n",
    "Now, let's check if our model works. We always initialize models with their ``from_dataset()`` method, which takes hyperparameters from the dataset, hyperparameters for the model and hyperparameters for the optimizer. Read more about it in the next section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Output(prediction=tensor([[-0.0244,  0.0269],\n",
       "        [-0.0449,  0.0361],\n",
       "        [-0.0271,  0.0307],\n",
       "        [-0.0246,  0.0271]], grad_fn=<AddBackward0>))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = FullyConnectedModel.from_dataset(dataset, input_size=5, output_size=2, hidden_size=10, n_hidden_layers=2)\n",
    "x, y = next(iter(dataloader))\n",
    "model(x)"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "If you want to know which group and time index (at the first prediction) each sample in the batch belongs to, you can use :py:meth:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet.x_to_index`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>time_idx</th>\n",
       "      <th>group</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>5</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>6</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>7</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   time_idx  group\n",
       "0         5      2\n",
       "1         6      2\n",
       "2         5      0\n",
       "3         7      0"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset.x_to_index(x)"
   ]
  },
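  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As far as we can tell, the time index reported here is the first entry of each row of `decoder_time_idx`, and the group comes from `groups`. The sketch below rebuilds such a table by hand from the tensors of the batch printed earlier (not the batch passed to `x_to_index` above, which is a different one). `index_from_batch` is a hypothetical helper, not the library's implementation, and it assumes that the integer group labels 0, 1 and 2 are encoded as themselves, so no decoding step is shown."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "# tensors copied from the batch printed earlier (decoder_time_idx and groups)\n",
    "printed_decoder_time_idx = torch.tensor([[6, 7], [6, 7], [7, 8], [6, 7]])\n",
    "printed_groups = torch.tensor([[1], [2], [0], [0]])\n",
    "\n",
    "\n",
    "def index_from_batch(decoder_time_idx: torch.Tensor, groups: torch.Tensor) -> pd.DataFrame:\n",
    "    # hypothetical helper: first prediction time step and (encoded) group id of each sample\n",
    "    return pd.DataFrame(\n",
    "        {\n",
    "            \"time_idx\": decoder_time_idx[:, 0].numpy(),\n",
    "            \"group\": groups[:, 0].numpy(),\n",
    "        }\n",
    "    )\n",
    "\n",
    "\n",
    "index_from_batch(printed_decoder_time_idx, printed_groups)"
   ]
  },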
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Coupling datasets and models"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "You might have noticed that the encoder and decoder/prediction lengths (5 and 2) are already specified in the :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet` and we specified them a second time when initializing the model. This might be acceptable for such a simple model but will make it hard for users to understand how to map from the dataset to the model parameters in more complicated settings.\n",
    "This is why we should implement another method in the model: ``from_dataset()``. Typically, a user would always initialize a model from a dataset. The method is also an opportunity to validate that the dataset defined by the user is compatible with your model architecture.\n",
    "\n",
    "While the :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet` and all PyTorch Forecasting metrics support different length time series, not every network architecture does."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FullyConnectedModel(BaseModel):\n",
    "    def __init__(self, input_size: int, output_size: int, hidden_size: int, n_hidden_layers: int, **kwargs):\n",
    "        # saves arguments in signature to `.hparams` attribute, mandatory call - do not skip this\n",
    "        self.save_hyperparameters()\n",
    "        # pass additional arguments to BaseModel.__init__, mandatory call - do not skip this\n",
    "        super().__init__(**kwargs)\n",
    "        self.network = FullyConnectedModule(\n",
    "            input_size=self.hparams.input_size,\n",
    "            output_size=self.hparams.output_size,\n",
    "            hidden_size=self.hparams.hidden_size,\n",
    "            n_hidden_layers=self.hparams.n_hidden_layers,\n",
    "        )\n",
    "\n",
    "    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n",
    "        # x is a batch generated based on the TimeSeriesDataSet\n",
    "        network_input = x[\"encoder_cont\"].squeeze(-1)\n",
    "        # unsqueeze(-1) adds a trailing dimension of size 1: batch_size x n_decoder_steps x 1\n",
    "        prediction = self.network(network_input).unsqueeze(-1)\n",
    "\n",
    "        # rescale predictions into target space\n",
    "        prediction = self.transform_output(prediction, target_scale=x[\"target_scale\"])\n",
    "\n",
    "        # The network output must contain at least the prediction.\n",
    "        # `to_network_output` wraps the keyword arguments into a named tuple.\n",
    "        return self.to_network_output(prediction=prediction)\n",
    "\n",
    "    @classmethod\n",
    "    def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs):\n",
    "        new_kwargs = {\n",
    "            \"output_size\": dataset.max_prediction_length,\n",
    "            \"input_size\": dataset.max_encoder_length,\n",
    "        }\n",
    "        new_kwargs.update(kwargs)  # use to pass real hyperparameters and override defaults set by dataset\n",
    "        # example for dataset validation\n",
    "        assert dataset.max_prediction_length == dataset.min_prediction_length, \"Decoder only supports a fixed length\"\n",
    "        assert dataset.min_encoder_length == dataset.max_encoder_length, \"Encoder only supports a fixed length\"\n",
    "        assert (\n",
    "            len(dataset.time_varying_known_categoricals) == 0\n",
    "            and len(dataset.time_varying_known_reals) == 0\n",
    "            and len(dataset.time_varying_unknown_categoricals) == 0\n",
    "            and len(dataset.static_categoricals) == 0\n",
    "            and len(dataset.static_reals) == 0\n",
    "            and len(dataset.time_varying_unknown_reals) == 1\n",
    "            and dataset.time_varying_unknown_reals[0] == dataset.target\n",
    "        ), \"The only covariate should be the target in 'time_varying_unknown_reals'\"\n",
    "\n",
    "        return super().from_dataset(dataset, **new_kwargs)"
   ]
  },
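  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The derive-then-override logic of `from_dataset()` does not depend on PyTorch Forecasting itself. The framework-free sketch below mirrors it with a hypothetical `DatasetSpec` stand-in (not a library class) and a hypothetical `model_kwargs_from_dataset` function: sizes are derived from the dataset, explicit keyword arguments win, and datasets the architecture cannot handle are rejected before any model is built."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class DatasetSpec:\n",
    "    # hypothetical stand-in for the TimeSeriesDataSet attributes used by from_dataset()\n",
    "    min_encoder_length: int\n",
    "    max_encoder_length: int\n",
    "    min_prediction_length: int\n",
    "    max_prediction_length: int\n",
    "\n",
    "\n",
    "def model_kwargs_from_dataset(spec: DatasetSpec, **kwargs) -> dict:\n",
    "    # defaults derived from the dataset\n",
    "    new_kwargs = {\"output_size\": spec.max_prediction_length, \"input_size\": spec.max_encoder_length}\n",
    "    # explicit arguments override the derived defaults\n",
    "    new_kwargs.update(kwargs)\n",
    "    # reject datasets the architecture cannot handle\n",
    "    assert spec.min_prediction_length == spec.max_prediction_length, \"Decoder only supports a fixed length\"\n",
    "    assert spec.min_encoder_length == spec.max_encoder_length, \"Encoder only supports a fixed length\"\n",
    "    return new_kwargs\n",
    "\n",
    "\n",
    "fixed = DatasetSpec(min_encoder_length=5, max_encoder_length=5, min_prediction_length=2, max_prediction_length=2)\n",
    "print(model_kwargs_from_dataset(fixed, hidden_size=10))  # {'output_size': 2, 'input_size': 5, 'hidden_size': 10}\n",
    "\n",
    "try:\n",
    "    model_kwargs_from_dataset(DatasetSpec(3, 5, 2, 2))  # variable encoder length\n",
    "except AssertionError as error:\n",
    "    print(\"rejected:\", error)  # rejected: Encoder only supports a fixed length"
   ]
  },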
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's initialize from our dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "   | Name                 | Type                 | Params\n",
      "---------------------------------------------------------------\n",
      "0  | loss                 | SMAPE                | 0     \n",
      "1  | logging_metrics      | ModuleList           | 0     \n",
      "2  | network              | FullyConnectedModule | 302   \n",
      "3  | network.sequential   | Sequential           | 302   \n",
      "4  | network.sequential.0 | Linear               | 60    \n",
      "5  | network.sequential.1 | ReLU                 | 0     \n",
      "6  | network.sequential.2 | Linear               | 110   \n",
      "7  | network.sequential.3 | ReLU                 | 0     \n",
      "8  | network.sequential.4 | Linear               | 110   \n",
      "9  | network.sequential.5 | ReLU                 | 0     \n",
      "10 | network.sequential.6 | Linear               | 22    \n",
      "---------------------------------------------------------------\n",
      "302       Trainable params\n",
      "0         Non-trainable params\n",
      "302       Total params\n",
      "0.001     Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"hidden_size\":                10\n",
       "\"input_size\":                 5\n",
       "\"learning_rate\":              0.001\n",
       "\"log_gradient_flow\":          False\n",
       "\"log_interval\":               -1\n",
       "\"log_val_interval\":           -1\n",
       "\"logging_metrics\":            ModuleList()\n",
       "\"loss\":                       SMAPE()\n",
       "\"monotone_constaints\":        {}\n",
       "\"n_hidden_layers\":            2\n",
       "\"optimizer\":                  ranger\n",
       "\"optimizer_params\":           None\n",
       "\"output_size\":                2\n",
       "\"output_transformer\":         GroupNormalizer(center=True, eps=1e-08, groups=[], method='standard',\n",
       "                scale_by_group=False, transformation=None)\n",
       "\"reduce_on_plateau_min_lr\":   1e-05\n",
       "\"reduce_on_plateau_patience\": 1000\n",
       "\"weight_decay\":               0.0"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = FullyConnectedModel.from_dataset(dataset, hidden_size=10, n_hidden_layers=2)\n",
    "model.summarize(\"full\")  # print model summary\n",
    "model.hparams"
   ]
  },
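  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameter counts in the summary above can be checked by hand. A `Linear(n_in, n_out)` layer has `n_in * n_out` weights plus `n_out` biases, and `ReLU` has none. Assuming, as the layer list suggests, that `FullyConnectedModule` consists of an input layer `Linear(input_size, hidden_size)`, `n_hidden_layers` layers `Linear(hidden_size, hidden_size)` and an output layer `Linear(hidden_size, output_size)`, a small sketch of the arithmetic (the helper functions are illustrative and not part of PyTorch Forecasting) reads:\n",
    "\n",
    "```python\n",
    "def linear_params(n_in: int, n_out: int) -> int:\n",
    "    # weight matrix (n_in x n_out) plus one bias per output unit\n",
    "    return n_in * n_out + n_out\n",
    "\n",
    "\n",
    "def fully_connected_params(input_size: int, output_size: int, hidden_size: int, n_hidden_layers: int) -> int:\n",
    "    total = linear_params(input_size, hidden_size)  # input layer\n",
    "    total += n_hidden_layers * linear_params(hidden_size, hidden_size)  # hidden layers\n",
    "    total += linear_params(hidden_size, output_size)  # output layer (ReLUs add no parameters)\n",
    "    return total\n",
    "\n",
    "\n",
    "print(fully_connected_params(input_size=5, output_size=2, hidden_size=10, n_hidden_layers=2))  # 302\n",
    "```\n",
    "\n",
    "Only the output layer depends on `output_size`: multiplying it by the number of classes, as the classification model further below does, gives `fully_connected_params(5, 6, 10, 2) == 346`."
   ]
  },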
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining additional hyperparameters"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "So far, we have kept a wildcard ``**kwargs`` argument in the model initialization signature. We then pass these ``**kwargs`` to the :py:class:`~pytorch_forecasting.models.base_model.BaseModel` using a ``super().__init__(**kwargs)`` call. We can see which additional hyperparameters are available as they are all saved in the ``hparams`` attribute of the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"hidden_size\":                10\n",
       "\"input_size\":                 5\n",
       "\"learning_rate\":              0.001\n",
       "\"log_gradient_flow\":          False\n",
       "\"log_interval\":               -1\n",
       "\"log_val_interval\":           -1\n",
       "\"logging_metrics\":            ModuleList()\n",
       "\"loss\":                       SMAPE()\n",
       "\"monotone_constaints\":        {}\n",
       "\"n_hidden_layers\":            2\n",
       "\"optimizer\":                  ranger\n",
       "\"optimizer_params\":           None\n",
       "\"output_size\":                2\n",
       "\"output_transformer\":         GroupNormalizer(center=True, eps=1e-08, groups=[], method='standard',\n",
       "                scale_by_group=False, transformation=None)\n",
       "\"reduce_on_plateau_min_lr\":   1e-05\n",
       "\"reduce_on_plateau_patience\": 1000\n",
       "\"weight_decay\":               0.0"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.hparams"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "While not required, it is worth passing these additional hyperparameters explicitly instead of implicitly via ``**kwargs`` to give users transparency over them.\n",
    "\n",
    "They are described in detail in the :py:class:`~pytorch_forecasting.models.base_model.BaseModel`. \n",
    "\n",
    ".. automethod:: pytorch_forecasting.models.base_model.BaseModel.__init__\n",
    "    :noindex:\n",
    "    \n",
    "You can simply copy this docstring into your model implementation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "        BaseModel for timeseries forecasting from which to inherit from\n",
      "\n",
      "        Args:\n",
      "            log_interval (Union[int, float], optional): Batches after which predictions are logged. If < 1.0, will log\n",
      "                multiple entries per batch. Defaults to -1.\n",
      "            log_val_interval (Union[int, float], optional): batches after which predictions for validation are\n",
      "                logged. Defaults to None/log_interval.\n",
      "            learning_rate (float, optional): Learning rate. Defaults to 1e-3.\n",
      "            log_gradient_flow (bool): If to log gradient flow, this takes time and should be only done to diagnose\n",
      "                training failures. Defaults to False.\n",
      "            loss (Metric, optional): metric to optimize, can also be list of metrics. Defaults to SMAPE().\n",
      "            logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that are logged during training.\n",
      "                Defaults to [].\n",
      "            reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10. Defaults\n",
      "                to 1000\n",
      "            reduce_on_plateau_min_lr (float): minimum learning rate for reduce on plateua learning rate scheduler.\n",
      "                Defaults to 1e-5\n",
      "            weight_decay (float): weight decay. Defaults to 0.0.\n",
      "            optimizer_params (Dict[str, Any]): additional parameters for the optimizer. Defaults to {}.\n",
      "            monotone_constaints (Dict[str, int]): dictionary of monotonicity constraints for continuous decoder\n",
      "                variables mapping\n",
      "                position (e.g. ``\"0\"`` for first position) to constraint (``-1`` for negative and ``+1`` for positive,\n",
      "                larger numbers add more weight to the constraint vs. the loss but are usually not necessary).\n",
      "                This constraint significantly slows down training. Defaults to {}.\n",
      "            output_transformer (Callable): transformer that takes network output and transforms it to prediction space.\n",
      "                Defaults to None which is equivalent to ``lambda out: out[\"prediction\"]``.\n",
      "            optimizer (str): Optimizer, \"ranger\", \"sgd\", \"adam\", \"adamw\" or class name of optimizer in ``torch.optim``.\n",
      "                Defaults to \"ranger\".\n",
      "        \n"
     ]
    }
   ],
   "source": [
    "print(BaseModel.__init__.__doc__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classification"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "Classification is a common task and can be easily implemented. In fact, we only have to change the target in our :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet` and adjust the number of prediction outputs to reflect the number of classes we want to predict. The changes for the :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet` are marked below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>target</th>\n",
       "      <th>value</th>\n",
       "      <th>group</th>\n",
       "      <th>time_idx</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>A</td>\n",
       "      <td>0.408078</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>A</td>\n",
       "      <td>0.742648</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>B</td>\n",
       "      <td>0.350471</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>C</td>\n",
       "      <td>0.540411</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>C</td>\n",
       "      <td>0.616381</td>\n",
       "      <td>0</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>A</td>\n",
       "      <td>0.035933</td>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>C</td>\n",
       "      <td>0.885356</td>\n",
       "      <td>0</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>C</td>\n",
       "      <td>0.042899</td>\n",
       "      <td>0</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>A</td>\n",
       "      <td>0.037230</td>\n",
       "      <td>0</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>A</td>\n",
       "      <td>0.251297</td>\n",
       "      <td>0</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>A</td>\n",
       "      <td>0.427348</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>A</td>\n",
       "      <td>0.005809</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>B</td>\n",
       "      <td>0.458371</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>C</td>\n",
       "      <td>0.607380</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>C</td>\n",
       "      <td>0.869184</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>C</td>\n",
       "      <td>0.344104</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>C</td>\n",
       "      <td>0.277108</td>\n",
       "      <td>1</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>B</td>\n",
       "      <td>0.050816</td>\n",
       "      <td>1</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>B</td>\n",
       "      <td>0.070120</td>\n",
       "      <td>1</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>A</td>\n",
       "      <td>0.019318</td>\n",
       "      <td>1</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>A</td>\n",
       "      <td>0.323767</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>B</td>\n",
       "      <td>0.716880</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>A</td>\n",
       "      <td>0.981400</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>A</td>\n",
       "      <td>0.222833</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>B</td>\n",
       "      <td>0.593633</td>\n",
       "      <td>2</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>A</td>\n",
       "      <td>0.183884</td>\n",
       "      <td>2</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>C</td>\n",
       "      <td>0.478769</td>\n",
       "      <td>2</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>B</td>\n",
       "      <td>0.781213</td>\n",
       "      <td>2</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>C</td>\n",
       "      <td>0.598958</td>\n",
       "      <td>2</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>C</td>\n",
       "      <td>0.198800</td>\n",
       "      <td>2</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   target     value  group  time_idx\n",
       "0       A  0.408078      0         0\n",
       "1       A  0.742648      0         1\n",
       "2       B  0.350471      0         2\n",
       "3       C  0.540411      0         3\n",
       "4       C  0.616381      0         4\n",
       "5       A  0.035933      0         5\n",
       "6       C  0.885356      0         6\n",
       "7       C  0.042899      0         7\n",
       "8       A  0.037230      0         8\n",
       "9       A  0.251297      0         9\n",
       "10      A  0.427348      1         0\n",
       "11      A  0.005809      1         1\n",
       "12      B  0.458371      1         2\n",
       "13      C  0.607380      1         3\n",
       "14      C  0.869184      1         4\n",
       "15      C  0.344104      1         5\n",
       "16      C  0.277108      1         6\n",
       "17      B  0.050816      1         7\n",
       "18      B  0.070120      1         8\n",
       "19      A  0.019318      1         9\n",
       "20      A  0.323767      2         0\n",
       "21      B  0.716880      2         1\n",
       "22      A  0.981400      2         2\n",
       "23      A  0.222833      2         3\n",
       "24      B  0.593633      2         4\n",
       "25      A  0.183884      2         5\n",
       "26      C  0.478769      2         6\n",
       "27      B  0.781213      2         7\n",
       "28      C  0.598958      2         8\n",
       "29      C  0.198800      2         9"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "classification_test_data = pd.DataFrame(\n",
    "    dict(\n",
    "        target=np.random.choice([\"A\", \"B\", \"C\"], size=30),  # CHANGING the values to predict to categorical labels\n",
    "        value=np.random.rand(30),  # INPUT values - see the section on covariates for how to use categorical inputs\n",
    "        group=np.repeat(np.arange(3), 10),\n",
    "        time_idx=np.tile(np.arange(10), 3),\n",
    "    )\n",
    ")\n",
    "classification_test_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 0],\n",
       "        [2, 2],\n",
       "        [0, 0],\n",
       "        [0, 2]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pytorch_forecasting.data.encoders import NaNLabelEncoder\n",
    "\n",
    "# create the dataset from the pandas dataframe\n",
    "classification_dataset = TimeSeriesDataSet(\n",
    "    classification_test_data,\n",
    "    group_ids=[\"group\"],\n",
    "    target=\"target\",  # SWITCHING to categorical target\n",
    "    time_idx=\"time_idx\",\n",
    "    min_encoder_length=5,\n",
    "    max_encoder_length=5,\n",
    "    min_prediction_length=2,\n",
    "    max_prediction_length=2,\n",
    "    time_varying_unknown_reals=[\"value\"],\n",
    "    target_normalizer=NaNLabelEncoder(),  # Use the NaNLabelEncoder to encode categorical target\n",
    ")\n",
    "\n",
    "x, y = next(iter(classification_dataset.to_dataloader(batch_size=4)))\n",
    "y[0]  # target values are encoded categories"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext",
    "tags": []
   },
   "source": [
    "The keyword argument ``target_normalizer`` is redundant here because the :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet` would have detected that a categorical target is used and therefore that a :py:class:`~pytorch_forecasting.data.encoders.NaNLabelEncoder` is required."
   ]
  },
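  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, encoding a categorical target means replacing each class label with an integer code, which is what the encoded `y[0]` above contains. The following NumPy sketch illustrates only that idea; it is not the `NaNLabelEncoder` implementation, which (judging by its `add_nan` option) can additionally reserve a code for missing values, and the sorted class order used here is an assumption of the sketch:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "targets = np.array(['A', 'A', 'B', 'C', 'C', 'A'])\n",
    "# np.unique returns the sorted distinct classes and, with return_inverse=True,\n",
    "# the position of each element's class in that sorted array\n",
    "classes, encoded = np.unique(targets, return_inverse=True)\n",
    "print(classes)  # ['A' 'B' 'C']\n",
    "print(encoded)  # [0 0 1 2 2 0]\n",
    "\n",
    "# decoding is a lookup of the integer codes in the class array\n",
    "decoded = classes[encoded]\n",
    "print(decoded)  # ['A' 'A' 'B' 'C' 'C' 'A']\n",
    "```"
   ]
  },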
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "Now, we need to modify our implementation of the ``FullyConnectedModel``. In particular, we have to add one hyperparameter to the model: ``n_classes``, which determines how\n",
    "many classes there are to predict. Our model will produce one number per class at each timestep; applying a softmax over the last dimension converts these numbers into class probabilities. This means we need a total of ``n_decoder_timesteps x n_classes`` outputs per sample. Further, we need to specify the default loss function, which we choose to be :py:class:`~pytorch_forecasting.metrics.CrossEntropy`."
   ]
  },
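  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the output layout concrete: the network emits `output_size * n_classes` numbers per sample, which are reshaped to `batch_size x n_decoder_timesteps x n_classes` and then turned into class probabilities with a softmax over the last dimension. A minimal NumPy sketch with random stand-in values (the sizes mirror the example batch; the array names are illustrative):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "batch_size, output_size, n_classes = 4, 2, 3  # sizes of the example batch\n",
    "rng = np.random.default_rng(0)\n",
    "# random stand-in for the flat network output: output_size * n_classes values per sample\n",
    "flat = rng.normal(size=(batch_size, output_size * n_classes))\n",
    "\n",
    "# reshape to batch_size x n_decoder_timesteps x n_classes (row-major, like torch's view)\n",
    "logits = flat.reshape(batch_size, -1, n_classes)\n",
    "\n",
    "# softmax over the last (class) dimension; subtracting the max keeps exp() numerically stable\n",
    "shifted = logits - logits.max(axis=-1, keepdims=True)\n",
    "probabilities = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)\n",
    "\n",
    "print(logits.shape)  # (4, 2, 3)\n",
    "print(np.allclose(probabilities.sum(axis=-1), 1.0))  # True\n",
    "```\n",
    "\n",
    "Taking `probabilities.argmax(axis=-1)` then yields the most likely class code for each sample and timestep."
   ]
  },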
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "   | Name                 | Type                 | Params\n",
      "---------------------------------------------------------------\n",
      "0  | loss                 | SMAPE                | 0     \n",
      "1  | logging_metrics      | ModuleList           | 0     \n",
      "2  | network              | FullyConnectedModule | 346   \n",
      "3  | network.sequential   | Sequential           | 346   \n",
      "4  | network.sequential.0 | Linear               | 60    \n",
      "5  | network.sequential.1 | ReLU                 | 0     \n",
      "6  | network.sequential.2 | Linear               | 110   \n",
      "7  | network.sequential.3 | ReLU                 | 0     \n",
      "8  | network.sequential.4 | Linear               | 110   \n",
      "9  | network.sequential.5 | ReLU                 | 0     \n",
      "10 | network.sequential.6 | Linear               | 66    \n",
      "---------------------------------------------------------------\n",
      "346       Trainable params\n",
      "0         Non-trainable params\n",
      "346       Total params\n",
      "0.001     Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"hidden_size\":                10\n",
       "\"input_size\":                 5\n",
       "\"learning_rate\":              0.001\n",
       "\"log_gradient_flow\":          False\n",
       "\"log_interval\":               -1\n",
       "\"log_val_interval\":           -1\n",
       "\"logging_metrics\":            ModuleList()\n",
       "\"loss\":                       CrossEntropy()\n",
       "\"monotone_constaints\":        {}\n",
       "\"n_classes\":                  3\n",
       "\"n_hidden_layers\":            2\n",
       "\"optimizer\":                  ranger\n",
       "\"optimizer_params\":           None\n",
       "\"output_size\":                2\n",
       "\"output_transformer\":         NaNLabelEncoder(add_nan=False, warn=True)\n",
       "\"reduce_on_plateau_min_lr\":   1e-05\n",
       "\"reduce_on_plateau_patience\": 1000\n",
       "\"weight_decay\":               0.0"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pytorch_forecasting.metrics import CrossEntropy\n",
    "\n",
    "\n",
    "class FullyConnectedClassificationModel(BaseModel):\n",
    "    def __init__(\n",
    "        self,\n",
    "        input_size: int,\n",
    "        output_size: int,\n",
    "        hidden_size: int,\n",
    "        n_hidden_layers: int,\n",
    "        n_classes: int,\n",
    "        loss=CrossEntropy(),\n",
    "        **kwargs,\n",
    "    ):\n",
    "        # saves arguments in signature to `.hparams` attribute, mandatory call - do not skip this\n",
    "        self.save_hyperparameters()\n",
    "        # pass the loss and additional arguments to BaseModel.__init__, mandatory call - do not skip this\n",
    "        super().__init__(loss=loss, **kwargs)\n",
    "        self.network = FullyConnectedModule(\n",
    "            input_size=self.hparams.input_size,\n",
    "            output_size=self.hparams.output_size * self.hparams.n_classes,\n",
    "            hidden_size=self.hparams.hidden_size,\n",
    "            n_hidden_layers=self.hparams.n_hidden_layers,\n",
    "        )\n",
    "\n",
    "    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n",
    "        # x is a batch generated based on the TimeSeriesDataset\n",
    "        batch_size = x[\"encoder_cont\"].size(0)\n",
    "        network_input = x[\"encoder_cont\"].squeeze(-1)\n",
    "        prediction = self.network(network_input)\n",
    "        # RESHAPE output to batch_size x n_decoder_timesteps x n_classes\n",
    "        prediction = prediction.view(batch_size, -1, self.hparams.n_classes)\n",
    "\n",
    "        # rescale predictions into target space\n",
    "        prediction = self.transform_output(prediction, target_scale=x[\"target_scale\"])\n",
    "\n",
    "        # We need to return a named tuple that at least contains the prediction.\n",
    "        # The parameter can be directly forwarded from the input.\n",
    "        # The conversion to a named tuple can be directly achieved with the `to_network_output` function.\n",
    "        return self.to_network_output(prediction=prediction)\n",
    "\n",
    "    @classmethod\n",
    "    def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs):\n",
    "        assert isinstance(dataset.target_normalizer, NaNLabelEncoder), \"target normalizer has to encode categories\"\n",
    "        new_kwargs = {\n",
    "            \"n_classes\": len(\n",
    "                dataset.target_normalizer.classes_\n",
    "            ),  # ADD number of classes as encoded by the target normalizer\n",
    "            \"output_size\": dataset.max_prediction_length,\n",
    "            \"input_size\": dataset.max_encoder_length,\n",
    "        }\n",
    "        new_kwargs.update(kwargs)  # use to pass real hyperparameters and override defaults set by dataset\n",
    "        # example for dataset validation\n",
    "        assert dataset.max_prediction_length == dataset.min_prediction_length, \"Decoder only supports a fixed length\"\n",
    "        assert dataset.min_encoder_length == dataset.max_encoder_length, \"Encoder only supports a fixed length\"\n",
    "        assert (\n",
    "            len(dataset.time_varying_known_categoricals) == 0\n",
    "            and len(dataset.time_varying_known_reals) == 0\n",
    "            and len(dataset.time_varying_unknown_categoricals) == 0\n",
    "            and len(dataset.static_categoricals) == 0\n",
    "            and len(dataset.static_reals) == 0\n",
    "            and len(dataset.time_varying_unknown_reals) == 1\n",
    "        ), \"Only covariate should be in 'time_varying_unknown_reals'\"\n",
    "\n",
    "        return super().from_dataset(dataset, **new_kwargs)\n",
    "\n",
    "\n",
    "model = FullyConnectedClassificationModel.from_dataset(classification_dataset, hidden_size=10, n_hidden_layers=2)\n",
    "model.summarize(\"full\")\n",
    "model.hparams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 2, 3])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# passing x through model\n",
    "model(x)[\"prediction\"].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predicting multiple targets at the same time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training a model to predict multiple targets simultaneously is not difficult to implement. We can even employ mixed targets, i.e. a mix of categorical and continuous targets. The first step is to define a dataframe with multiple targets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>target1</th>\n",
       "      <th>target2</th>\n",
       "      <th>group</th>\n",
       "      <th>time_idx</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.117679</td>\n",
       "      <td>0.609905</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.534599</td>\n",
       "      <td>0.045133</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.060970</td>\n",
       "      <td>0.436436</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.044410</td>\n",
       "      <td>0.126944</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.409618</td>\n",
       "      <td>0.516195</td>\n",
       "      <td>0</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0.139532</td>\n",
       "      <td>0.496374</td>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0.727542</td>\n",
       "      <td>0.961093</td>\n",
       "      <td>0</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0.183952</td>\n",
       "      <td>0.299596</td>\n",
       "      <td>0</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>0.997208</td>\n",
       "      <td>0.637554</td>\n",
       "      <td>0</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>0.587483</td>\n",
       "      <td>0.627798</td>\n",
       "      <td>0</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>0.212362</td>\n",
       "      <td>0.955963</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>0.843313</td>\n",
       "      <td>0.747749</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>0.901079</td>\n",
       "      <td>0.259364</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>0.690811</td>\n",
       "      <td>0.927396</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>0.018312</td>\n",
       "      <td>0.874168</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>0.388948</td>\n",
       "      <td>0.390768</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>0.958272</td>\n",
       "      <td>0.020205</td>\n",
       "      <td>1</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>0.609807</td>\n",
       "      <td>0.608232</td>\n",
       "      <td>1</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>0.038265</td>\n",
       "      <td>0.599644</td>\n",
       "      <td>1</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>0.867473</td>\n",
       "      <td>0.382473</td>\n",
       "      <td>1</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>0.594093</td>\n",
       "      <td>0.667170</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>0.233055</td>\n",
       "      <td>0.861108</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>0.577353</td>\n",
       "      <td>0.550609</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>0.763409</td>\n",
       "      <td>0.712799</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>0.626900</td>\n",
       "      <td>0.562910</td>\n",
       "      <td>2</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>0.322371</td>\n",
       "      <td>0.250461</td>\n",
       "      <td>2</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>0.410469</td>\n",
       "      <td>0.270854</td>\n",
       "      <td>2</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>0.053525</td>\n",
       "      <td>0.757942</td>\n",
       "      <td>2</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>0.306615</td>\n",
       "      <td>0.744437</td>\n",
       "      <td>2</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>0.048932</td>\n",
       "      <td>0.573059</td>\n",
       "      <td>2</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     target1   target2  group  time_idx\n",
       "0   0.117679  0.609905      0         0\n",
       "1   0.534599  0.045133      0         1\n",
       "2   0.060970  0.436436      0         2\n",
       "3   0.044410  0.126944      0         3\n",
       "4   0.409618  0.516195      0         4\n",
       "5   0.139532  0.496374      0         5\n",
       "6   0.727542  0.961093      0         6\n",
       "7   0.183952  0.299596      0         7\n",
       "8   0.997208  0.637554      0         8\n",
       "9   0.587483  0.627798      0         9\n",
       "10  0.212362  0.955963      1         0\n",
       "11  0.843313  0.747749      1         1\n",
       "12  0.901079  0.259364      1         2\n",
       "13  0.690811  0.927396      1         3\n",
       "14  0.018312  0.874168      1         4\n",
       "15  0.388948  0.390768      1         5\n",
       "16  0.958272  0.020205      1         6\n",
       "17  0.609807  0.608232      1         7\n",
       "18  0.038265  0.599644      1         8\n",
       "19  0.867473  0.382473      1         9\n",
       "20  0.594093  0.667170      2         0\n",
       "21  0.233055  0.861108      2         1\n",
       "22  0.577353  0.550609      2         2\n",
       "23  0.763409  0.712799      2         3\n",
       "24  0.626900  0.562910      2         4\n",
       "25  0.322371  0.250461      2         5\n",
       "26  0.410469  0.270854      2         6\n",
       "27  0.053525  0.757942      2         7\n",
       "28  0.306615  0.744437      2         8\n",
       "29  0.048932  0.573059      2         9"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "multi_target_test_data = pd.DataFrame(\n",
    "    dict(\n",
    "        target1=np.random.rand(30),\n",
    "        target2=np.random.rand(30),\n",
    "        group=np.repeat(np.arange(3), 10),\n",
    "        time_idx=np.tile(np.arange(10), 3),\n",
    "    )\n",
    ")\n",
    "multi_target_test_data"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "We can then simply pass a list to the ``target`` keyword of the :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. The class will choose reasonable defaults for normalizing the targets, but we can also specify the normalizer explicitly by assigning an instance of :py:class:`~pytorch_forecasting.data.encoders.MultiNormalizer` to the ``target_normalizer`` keyword. For fun, let's use a different normalization method for each target."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([[0.0535, 0.3066],\n",
       "         [0.7275, 0.1840],\n",
       "         [0.0383, 0.8675],\n",
       "         [0.4105, 0.0535]]),\n",
       " tensor([[0.7579, 0.7444],\n",
       "         [0.9611, 0.2996],\n",
       "         [0.5996, 0.3825],\n",
       "         [0.2709, 0.7579]])]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pytorch_forecasting.data.encoders import EncoderNormalizer, MultiNormalizer, TorchNormalizer\n",
    "\n",
    "# create the dataset from the pandas dataframe\n",
    "multi_target_dataset = TimeSeriesDataSet(\n",
    "    multi_target_test_data,\n",
    "    group_ids=[\"group\"],\n",
    "    target=[\"target1\", \"target2\"],  # USING two targets\n",
    "    time_idx=\"time_idx\",\n",
    "    min_encoder_length=5,\n",
    "    max_encoder_length=5,\n",
    "    min_prediction_length=2,\n",
    "    max_prediction_length=2,\n",
    "    time_varying_unknown_reals=[\"target1\", \"target2\"],\n",
    "    target_normalizer=MultiNormalizer(\n",
    "        [EncoderNormalizer(), TorchNormalizer()]\n",
    "    ),  # normalize target1 with an EncoderNormalizer and target2 with a TorchNormalizer\n",
    ")\n",
    "\n",
    "x, y = next(iter(multi_target_dataset.to_dataloader(batch_size=4)))\n",
    "y[0]  # target values are a list of targets"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "Using multiple targets leads to a slightly different ``x`` and ``y`` of the :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`'s dataloader.\n",
    "``y`` is still a tuple of target and weight, but the target is now a list of tensors. The same holds for ``target_scale``, ``encoder_target`` and ``decoder_target`` in ``x``.\n",
    "\n",
    "For this reason, not every model can automatically handle multiple targets. However, it is usually fairly simple to extend a model to output a list of tensors (one per target) instead of a single tensor (for one target). We will now modify our ``FullyConnectedModel`` to work with one or more targets.\n",
    "\n",
    "As we use multiple targets, we need to define a loss function that can handle them. The :py:class:`~pytorch_forecasting.metrics.MultiLoss` is built exactly for that purpose. It also allows weighting the losses differently. Solely for demonstration purposes, we optimize the mean absolute error (MAE) for the first target and the symmetric mean absolute percentage error (SMAPE) for the second target. We weight the error on the first target twice as heavily as the error on the second target."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "   | Name                 | Type                 | Params\n",
      "---------------------------------------------------------------\n",
      "0  | loss                 | MultiLoss            | 0     \n",
      "1  | logging_metrics      | ModuleList           | 0     \n",
      "2  | network              | FullyConnectedModule | 374   \n",
      "3  | network.sequential   | Sequential           | 374   \n",
      "4  | network.sequential.0 | Linear               | 110   \n",
      "5  | network.sequential.1 | ReLU                 | 0     \n",
      "6  | network.sequential.2 | Linear               | 110   \n",
      "7  | network.sequential.3 | ReLU                 | 0     \n",
      "8  | network.sequential.4 | Linear               | 110   \n",
      "9  | network.sequential.5 | ReLU                 | 0     \n",
      "10 | network.sequential.6 | Linear               | 44    \n",
      "---------------------------------------------------------------\n",
      "374       Trainable params\n",
      "0         Non-trainable params\n",
      "374       Total params\n",
      "0.001     Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"hidden_size\":                10\n",
       "\"input_size\":                 5\n",
       "\"learning_rate\":              0.001\n",
       "\"log_gradient_flow\":          False\n",
       "\"log_interval\":               -1\n",
       "\"log_val_interval\":           -1\n",
       "\"logging_metrics\":            ModuleList()\n",
       "\"loss\":                       MultiLoss(2 * MAE(), SMAPE())\n",
       "\"monotone_constaints\":        {}\n",
       "\"n_hidden_layers\":            2\n",
       "\"optimizer\":                  ranger\n",
       "\"optimizer_params\":           None\n",
       "\"output_size\":                2\n",
       "\"output_transformer\":         MultiNormalizer(normalizers=[EncoderNormalizer(center=True, eps=1e-08,\n",
       "                                               method='standard',\n",
       "                                               transformation=None),\n",
       "                             TorchNormalizer(center=True, eps=1e-08,\n",
       "                                             method='standard',\n",
       "                                             transformation=None)])\n",
       "\"reduce_on_plateau_min_lr\":   1e-05\n",
       "\"reduce_on_plateau_patience\": 1000\n",
       "\"target_sizes\":               [1, 1]\n",
       "\"weight_decay\":               0.0"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from typing import Dict, List, Union\n",
    "\n",
    "from pytorch_forecasting.metrics import MAE, SMAPE, MultiLoss\n",
    "from pytorch_forecasting.utils import to_list\n",
    "\n",
    "\n",
    "class FullyConnectedMultiTargetModel(BaseModel):\n",
    "    def __init__(\n",
    "        self,\n",
    "        input_size: int,\n",
    "        output_size: int,\n",
    "        hidden_size: int,\n",
    "        n_hidden_layers: int,\n",
    "        target_sizes: Union[int, List[int]] = [],\n",
    "        **kwargs,\n",
    "    ):\n",
    "        # saves arguments in signature to `.hparams` attribute, mandatory call - do not skip this\n",
    "        self.save_hyperparameters()\n",
    "        # pass additional arguments to BaseModel.__init__, mandatory call - do not skip this\n",
    "        super().__init__(**kwargs)\n",
    "        self.network = FullyConnectedModule(\n",
    "            input_size=self.hparams.input_size * len(to_list(self.hparams.target_sizes)),\n",
    "            output_size=self.hparams.output_size * sum(to_list(self.hparams.target_sizes)),\n",
    "            hidden_size=self.hparams.hidden_size,\n",
    "            n_hidden_layers=self.hparams.n_hidden_layers,\n",
    "        )\n",
    "\n",
    "    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n",
    "        # x is a batch generated based on the TimeSeriesDataset\n",
    "        batch_size = x[\"encoder_cont\"].size(0)\n",
    "        network_input = x[\"encoder_cont\"].view(batch_size, -1)\n",
    "        prediction = self.network(network_input)\n",
    "        # RESHAPE output to batch_size x n_decoder_timesteps x sum_of_target_sizes\n",
    "        prediction = prediction.view(batch_size, self.hparams.output_size, sum(to_list(self.hparams.target_sizes)))\n",
    "        # RESHAPE into list of batch_size x n_decoder_timesteps x target_sizes[i] where i=1..len(target_sizes)\n",
    "        stops = np.cumsum(self.hparams.target_sizes)\n",
    "        starts = stops - self.hparams.target_sizes\n",
    "        prediction = [prediction[..., start:stop] for start, stop in zip(starts, stops)]\n",
    "        if isinstance(self.hparams.target_sizes, int):  # only one target\n",
    "            prediction = prediction[0]\n",
    "\n",
    "        # rescale predictions into target space\n",
    "        prediction = self.transform_output(prediction, target_scale=x[\"target_scale\"])\n",
    "\n",
    "        # We need to return a named tuple that at least contains the prediction.\n",
    "        # The parameter can be directly forwarded from the input.\n",
    "        # The conversion to a named tuple can be directly achieved with the `to_network_output` function.\n",
    "        return self.to_network_output(prediction=prediction)\n",
    "\n",
    "    @classmethod\n",
    "    def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs):\n",
    "        # By default only handle targets of size one here, categorical targets would be of larger size\n",
    "        new_kwargs = {\n",
    "            \"target_sizes\": [1] * len(to_list(dataset.target)),\n",
    "            \"output_size\": dataset.max_prediction_length,\n",
    "            \"input_size\": dataset.max_encoder_length,\n",
    "        }\n",
    "        new_kwargs.update(kwargs)  # use to pass real hyperparameters and override defaults set by dataset\n",
    "        # example for dataset validation\n",
    "        assert dataset.max_prediction_length == dataset.min_prediction_length, \"Decoder only supports a fixed length\"\n",
    "        assert dataset.min_encoder_length == dataset.max_encoder_length, \"Encoder only supports a fixed length\"\n",
    "        assert (\n",
    "            len(dataset.time_varying_known_categoricals) == 0\n",
    "            and len(dataset.time_varying_known_reals) == 0\n",
    "            and len(dataset.time_varying_unknown_categoricals) == 0\n",
    "            and len(dataset.static_categoricals) == 0\n",
    "            and len(dataset.static_reals) == 0\n",
    "            and len(dataset.time_varying_unknown_reals)\n",
    "            == len(dataset.target_names)  # expect as many unknown reals as targets\n",
    "        ), \"The only covariates should be the targets in 'time_varying_unknown_reals'\"\n",
    "\n",
    "        return super().from_dataset(dataset, **new_kwargs)\n",
    "\n",
    "\n",
    "model = FullyConnectedMultiTargetModel.from_dataset(\n",
    "    multi_target_dataset,\n",
    "    hidden_size=10,\n",
    "    n_hidden_layers=2,\n",
    "    loss=MultiLoss(metrics=[MAE(), SMAPE()], weights=[2.0, 1.0]),\n",
    ")\n",
    "model.summarize(\"full\")\n",
    "model.hparams"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's pass some data through our model and calculate the loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Output(prediction=[tensor([[[0.4985],\n",
       "         [0.5652]],\n",
       "\n",
       "        [[0.1833],\n",
       "         [0.2609]],\n",
       "\n",
       "        [[0.4349],\n",
       "         [0.5695]],\n",
       "\n",
       "        [[0.4508],\n",
       "         [0.5309]]], grad_fn=<AddBackward0>), tensor([[[0.5504],\n",
       "         [0.6183]],\n",
       "\n",
       "        [[0.5397],\n",
       "         [0.6063]],\n",
       "\n",
       "        [[0.5433],\n",
       "         [0.6266]],\n",
       "\n",
       "        [[0.5444],\n",
       "         [0.6143]]], dtype=torch.float64, grad_fn=<AddBackward0>)])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out = model(x)\n",
    "out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1.0348, dtype=torch.float64, grad_fn=<SumBackward1>)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.loss(out[\"prediction\"], y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using covariates"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "Now that we have established the basics, we can move on to more advanced use cases, e.g. how to make use of covariates, static and continuous alike. We can leverage the :py:class:`~pytorch_forecasting.models.base_model.BaseModelWithCovariates` for this. It differs from the :py:class:`~pytorch_forecasting.models.base_model.BaseModel` by a :py:meth:`~pytorch_forecasting.models.base_model.BaseModelWithCovariates.from_dataset` method that pre-defines hyperparameters for architectures with covariates.\n",
    "\n",
    ".. autoclass:: pytorch_forecasting.models.base_model.BaseModelWithCovariates\n",
    "    :noindex:\n",
    "    :members: from_dataset\n",
    "    \n",
    "\n",
    "Here is the docstring of the BaseModelWithCovariates, printed so you can copy its list of hyperparameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "    Model with additional methods using covariates.\n",
      "\n",
      "    Assumes the following hyperparameters:\n",
      "\n",
      "    Args:\n",
      "        static_categoricals (List[str]): names of static categorical variables\n",
      "        static_reals (List[str]): names of static continuous variables\n",
      "        time_varying_categoricals_encoder (List[str]): names of categorical variables for encoder\n",
      "        time_varying_categoricals_decoder (List[str]): names of categorical variables for decoder\n",
      "        time_varying_reals_encoder (List[str]): names of continuous variables for encoder\n",
      "        time_varying_reals_decoder (List[str]): names of continuous variables for decoder\n",
      "        x_reals (List[str]): order of continuous variables in tensor passed to forward function\n",
      "        x_categoricals (List[str]): order of categorical variables in tensor passed to forward function\n",
      "        embedding_sizes (Dict[str, Tuple[int, int]]): dictionary mapping categorical variables to tuple of integers\n",
      "            where the first integer denotes the number of categorical classes and the second the embedding size\n",
      "        embedding_labels (Dict[str, List[str]]): dictionary mapping (string) indices to list of categorical labels\n",
      "        embedding_paddings (List[str]): names of categorical variables for which label 0 is always mapped to an\n",
      "             embedding vector filled with zeros\n",
      "        categorical_groups (Dict[str, List[str]]): dictionary of categorical variables that are grouped together and\n",
      "            can also take multiple values simultaneously (e.g. holiday during octoberfest). They should be implemented\n",
      "            as bag of embeddings\n",
      "    \n"
     ]
    }
   ],
   "source": [
    "from pytorch_forecasting.models.base_model import BaseModelWithCovariates\n",
    "\n",
    "print(BaseModelWithCovariates.__doc__)"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "We will now implement the model. A helpful module is the :py:class:`~pytorch_forecasting.models.nn.embeddings.MultiEmbedding`, which can be used to embed categorical features. It is compliant with the :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`, i.e. it supports bags of embeddings, which are useful when multiple categories can occur at the same time, such as holidays. Again, we will create a fully-connected network. It is easy to recycle our ``FullyConnectedModule`` by setting ``input_size`` to the number of encoder time steps times the number of features instead of simply the number of encoder time steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict, List, Tuple\n",
    "\n",
    "from pytorch_forecasting.models.nn import MultiEmbedding\n",
    "\n",
    "\n",
    "class FullyConnectedModelWithCovariates(BaseModelWithCovariates):\n",
    "    def __init__(\n",
    "        self,\n",
    "        input_size: int,\n",
    "        output_size: int,\n",
    "        hidden_size: int,\n",
    "        n_hidden_layers: int,\n",
    "        x_reals: List[str],\n",
    "        x_categoricals: List[str],\n",
    "        embedding_sizes: Dict[str, Tuple[int, int]],\n",
    "        embedding_labels: Dict[str, List[str]],\n",
    "        static_categoricals: List[str],\n",
    "        static_reals: List[str],\n",
    "        time_varying_categoricals_encoder: List[str],\n",
    "        time_varying_categoricals_decoder: List[str],\n",
    "        time_varying_reals_encoder: List[str],\n",
    "        time_varying_reals_decoder: List[str],\n",
    "        embedding_paddings: List[str],\n",
    "        categorical_groups: Dict[str, List[str]],\n",
    "        **kwargs,\n",
    "    ):\n",
    "        # saves arguments in signature to `.hparams` attribute, mandatory call - do not skip this\n",
    "        self.save_hyperparameters()\n",
    "        # pass additional arguments to BaseModel.__init__, mandatory call - do not skip this\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "        # create embedder - can be fed with x[\"encoder_cat\"] or x[\"decoder_cat\"] and will return\n",
    "        # dictionary of category names mapped to embeddings\n",
    "        self.input_embeddings = MultiEmbedding(\n",
    "            embedding_sizes=self.hparams.embedding_sizes,\n",
    "            categorical_groups=self.hparams.categorical_groups,\n",
    "            embedding_paddings=self.hparams.embedding_paddings,\n",
    "            x_categoricals=self.hparams.x_categoricals,\n",
    "            max_embedding_size=self.hparams.hidden_size,\n",
    "        )\n",
    "\n",
    "        # calculate the size of all concatenated embeddings + continuous variables\n",
    "        n_features = sum(\n",
    "            embedding_size for classes_size, embedding_size in self.hparams.embedding_sizes.values()\n",
    "        ) + len(self.reals)\n",
    "\n",
    "        # create network that will be fed with continuous variables and embeddings\n",
    "        self.network = FullyConnectedModule(\n",
    "            input_size=self.hparams.input_size * n_features,\n",
    "            output_size=self.hparams.output_size,\n",
    "            hidden_size=self.hparams.hidden_size,\n",
    "            n_hidden_layers=self.hparams.n_hidden_layers,\n",
    "        )\n",
    "\n",
    "    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n",
    "        # x is a batch generated based on the TimeSeriesDataset\n",
    "        batch_size = x[\"encoder_lengths\"].size(0)\n",
    "        embeddings = self.input_embeddings(x[\"encoder_cat\"])  # returns dictionary with embedding tensors\n",
    "        network_input = torch.cat(\n",
    "            [x[\"encoder_cont\"]]\n",
    "            + [\n",
    "                emb\n",
    "                for name, emb in embeddings.items()\n",
    "                if name in self.encoder_variables or name in self.static_variables\n",
    "            ],\n",
    "            dim=-1,\n",
    "        )\n",
    "        prediction = self.network(network_input.view(batch_size, -1))\n",
    "\n",
    "        # rescale predictions into target space\n",
    "        prediction = self.transform_output(prediction, target_scale=x[\"target_scale\"])\n",
    "\n",
    "        # We need to return a named tuple that at least contains the prediction.\n",
    "        # The parameter can be directly forwarded from the input.\n",
    "        # The conversion to a named tuple can be directly achieved with the `to_network_output` function.\n",
    "        return self.to_network_output(prediction=prediction)\n",
    "\n",
    "    @classmethod\n",
    "    def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs):\n",
    "        new_kwargs = {\n",
    "            \"output_size\": dataset.max_prediction_length,\n",
    "            \"input_size\": dataset.max_encoder_length,\n",
    "        }\n",
    "        new_kwargs.update(kwargs)  # use to pass real hyperparameters and override defaults set by dataset\n",
    "        # example for dataset validation\n",
    "        assert dataset.max_prediction_length == dataset.min_prediction_length, \"Decoder only supports a fixed length\"\n",
    "        assert dataset.min_encoder_length == dataset.max_encoder_length, \"Encoder only supports a fixed length\"\n",
    "\n",
    "        return super().from_dataset(dataset, **new_kwargs)"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "Here, we have used additional hooks available through the :py:class:`~pytorch_forecasting.models.base_model.BaseModelWithCovariates`, such as ``self.static_variables`` or ``self.encoder_variables``, which are readily determined from the hyperparameters. See the documentation of the :py:class:`~pytorch_forecasting.models.base_model.BaseModelWithCovariates` class for all available additions to the :py:class:`~pytorch_forecasting.models.base_model.BaseModel`.\n",
    "\n",
    "When the model receives its input ``x``, you can use the hyperparameters and the additional variables provided by the :py:class:`~pytorch_forecasting.models.base_model.BaseModelWithCovariates` to identify the different variables. This is important because ``x[\"encoder_cat\"].size(2) == x[\"decoder_cat\"].size(2)`` and ``x[\"encoder_cont\"].size(2) == x[\"decoder_cont\"].size(2)``. This means all variables are passed to both the encoder and the decoder, even if some must not be used by the decoder because they are not known in the future. The order of variables in ``x[\"encoder_cont\"]`` / ``x[\"decoder_cont\"]`` and ``x[\"encoder_cat\"]`` / ``x[\"decoder_cat\"]`` is determined by the hyperparameters ``x_reals`` and ``x_categoricals``. Consequently, you can identify, for example, the positions of all continuous decoder variables with ``[self.hparams.x_reals.index(name) for name in self.hparams.time_varying_reals_decoder]``."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the model does not make use of the known covariates in the decoder. This is obviously suboptimal but beyond the scope of this tutorial. Anyway, let us create a new dataset with categorical variables and see how the model can be instantiated from it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>value</th>\n",
       "      <th>group</th>\n",
       "      <th>time_idx</th>\n",
       "      <th>categorical_covariate</th>\n",
       "      <th>real_covariate</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.103204</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>b</td>\n",
       "      <td>0.685154</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.779360</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>b</td>\n",
       "      <td>0.070320</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.926564</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>a</td>\n",
       "      <td>0.688759</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.534834</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>a</td>\n",
       "      <td>0.142168</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.662401</td>\n",
       "      <td>0</td>\n",
       "      <td>4</td>\n",
       "      <td>b</td>\n",
       "      <td>0.987386</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0.981731</td>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "      <td>a</td>\n",
       "      <td>0.909865</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0.310089</td>\n",
       "      <td>0</td>\n",
       "      <td>6</td>\n",
       "      <td>a</td>\n",
       "      <td>0.040515</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0.507049</td>\n",
       "      <td>0</td>\n",
       "      <td>7</td>\n",
       "      <td>a</td>\n",
       "      <td>0.249007</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>0.313872</td>\n",
       "      <td>0</td>\n",
       "      <td>8</td>\n",
       "      <td>b</td>\n",
       "      <td>0.357797</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>0.071876</td>\n",
       "      <td>0</td>\n",
       "      <td>9</td>\n",
       "      <td>b</td>\n",
       "      <td>0.454941</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>0.703754</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>a</td>\n",
       "      <td>0.291920</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>0.393297</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>b</td>\n",
       "      <td>0.077429</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>0.111496</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>b</td>\n",
       "      <td>0.743708</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>0.769912</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>b</td>\n",
       "      <td>0.598697</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>0.242925</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>b</td>\n",
       "      <td>0.360077</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>0.729829</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>a</td>\n",
       "      <td>0.094971</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>0.904733</td>\n",
       "      <td>1</td>\n",
       "      <td>6</td>\n",
       "      <td>b</td>\n",
       "      <td>0.019580</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>0.490206</td>\n",
       "      <td>1</td>\n",
       "      <td>7</td>\n",
       "      <td>b</td>\n",
       "      <td>0.369545</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>0.912757</td>\n",
       "      <td>1</td>\n",
       "      <td>8</td>\n",
       "      <td>a</td>\n",
       "      <td>0.566772</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>0.485278</td>\n",
       "      <td>1</td>\n",
       "      <td>9</td>\n",
       "      <td>a</td>\n",
       "      <td>0.581759</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>0.637475</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>a</td>\n",
       "      <td>0.002411</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>0.627253</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>a</td>\n",
       "      <td>0.726943</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>0.715399</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>a</td>\n",
       "      <td>0.691760</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>0.845473</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>b</td>\n",
       "      <td>0.820702</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>0.056837</td>\n",
       "      <td>2</td>\n",
       "      <td>4</td>\n",
       "      <td>b</td>\n",
       "      <td>0.690101</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>0.217073</td>\n",
       "      <td>2</td>\n",
       "      <td>5</td>\n",
       "      <td>a</td>\n",
       "      <td>0.664176</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>0.701934</td>\n",
       "      <td>2</td>\n",
       "      <td>6</td>\n",
       "      <td>b</td>\n",
       "      <td>0.941609</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>0.663467</td>\n",
       "      <td>2</td>\n",
       "      <td>7</td>\n",
       "      <td>b</td>\n",
       "      <td>0.453616</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>0.819870</td>\n",
       "      <td>2</td>\n",
       "      <td>8</td>\n",
       "      <td>a</td>\n",
       "      <td>0.690301</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>0.615967</td>\n",
       "      <td>2</td>\n",
       "      <td>9</td>\n",
       "      <td>b</td>\n",
       "      <td>0.708300</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       value group  time_idx categorical_covariate  real_covariate\n",
       "0   0.103204     0         0                     b        0.685154\n",
       "1   0.779360     0         1                     b        0.070320\n",
       "2   0.926564     0         2                     a        0.688759\n",
       "3   0.534834     0         3                     a        0.142168\n",
       "4   0.662401     0         4                     b        0.987386\n",
       "5   0.981731     0         5                     a        0.909865\n",
       "6   0.310089     0         6                     a        0.040515\n",
       "7   0.507049     0         7                     a        0.249007\n",
       "8   0.313872     0         8                     b        0.357797\n",
       "9   0.071876     0         9                     b        0.454941\n",
       "10  0.703754     1         0                     a        0.291920\n",
       "11  0.393297     1         1                     b        0.077429\n",
       "12  0.111496     1         2                     b        0.743708\n",
       "13  0.769912     1         3                     b        0.598697\n",
       "14  0.242925     1         4                     b        0.360077\n",
       "15  0.729829     1         5                     a        0.094971\n",
       "16  0.904733     1         6                     b        0.019580\n",
       "17  0.490206     1         7                     b        0.369545\n",
       "18  0.912757     1         8                     a        0.566772\n",
       "19  0.485278     1         9                     a        0.581759\n",
       "20  0.637475     2         0                     a        0.002411\n",
       "21  0.627253     2         1                     a        0.726943\n",
       "22  0.715399     2         2                     a        0.691760\n",
       "23  0.845473     2         3                     b        0.820702\n",
       "24  0.056837     2         4                     b        0.690101\n",
       "25  0.217073     2         5                     a        0.664176\n",
       "26  0.701934     2         6                     b        0.941609\n",
       "27  0.663467     2         7                     b        0.453616\n",
       "28  0.819870     2         8                     a        0.690301\n",
       "29  0.615967     2         9                     b        0.708300"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from pytorch_forecasting import TimeSeriesDataSet\n",
    "\n",
    "test_data_with_covariates = pd.DataFrame(\n",
    "    dict(\n",
    "        # as before\n",
    "        value=np.random.rand(30),\n",
    "        group=np.repeat(np.arange(3), 10),\n",
    "        time_idx=np.tile(np.arange(10), 3),\n",
    "        # now adding covariates\n",
    "        categorical_covariate=np.random.choice([\"a\", \"b\"], size=30),\n",
    "        real_covariate=np.random.rand(30),\n",
    "    )\n",
    ").astype(\n",
    "    dict(group=str)\n",
    ")  # categorical covariates have to be of string type\n",
    "test_data_with_covariates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "   | Name                                              | Type                 | Params\n",
      "--------------------------------------------------------------------------------------------\n",
      "0  | loss                                              | SMAPE                | 0     \n",
      "1  | logging_metrics                                   | ModuleList           | 0     \n",
      "2  | input_embeddings                                  | MultiEmbedding       | 11    \n",
      "3  | input_embeddings.embeddings                       | ModuleDict           | 11    \n",
      "4  | input_embeddings.embeddings.group                 | Embedding            | 9     \n",
      "5  | input_embeddings.embeddings.categorical_covariate | Embedding            | 2     \n",
      "6  | network                                           | FullyConnectedModule | 552   \n",
      "7  | network.sequential                                | Sequential           | 552   \n",
      "8  | network.sequential.0                              | Linear               | 310   \n",
      "9  | network.sequential.1                              | ReLU                 | 0     \n",
      "10 | network.sequential.2                              | Linear               | 110   \n",
      "11 | network.sequential.3                              | ReLU                 | 0     \n",
      "12 | network.sequential.4                              | Linear               | 110   \n",
      "13 | network.sequential.5                              | ReLU                 | 0     \n",
      "14 | network.sequential.6                              | Linear               | 22    \n",
      "--------------------------------------------------------------------------------------------\n",
      "563       Trainable params\n",
      "0         Non-trainable params\n",
      "563       Total params\n",
      "0.002     Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"categorical_groups\":                {}\n",
       "\"embedding_labels\":                  {'group': {'0': 0, '1': 1, '2': 2}, 'categorical_covariate': {'a': 0, 'b': 1}}\n",
       "\"embedding_paddings\":                []\n",
       "\"embedding_sizes\":                   {'group': [3, 3], 'categorical_covariate': [2, 1]}\n",
       "\"hidden_size\":                       10\n",
       "\"input_size\":                        5\n",
       "\"learning_rate\":                     0.001\n",
       "\"log_gradient_flow\":                 False\n",
       "\"log_interval\":                      -1\n",
       "\"log_val_interval\":                  -1\n",
       "\"logging_metrics\":                   ModuleList()\n",
       "\"loss\":                              SMAPE()\n",
       "\"monotone_constaints\":               {}\n",
       "\"n_hidden_layers\":                   2\n",
       "\"optimizer\":                         ranger\n",
       "\"optimizer_params\":                  None\n",
       "\"output_size\":                       2\n",
       "\"output_transformer\":                GroupNormalizer(center=True, eps=1e-08, groups=[], method='standard',\n",
       "                scale_by_group=False, transformation='relu')\n",
       "\"reduce_on_plateau_min_lr\":          1e-05\n",
       "\"reduce_on_plateau_patience\":        1000\n",
       "\"static_categoricals\":               ['group']\n",
       "\"static_reals\":                      []\n",
       "\"time_varying_categoricals_decoder\": ['categorical_covariate']\n",
       "\"time_varying_categoricals_encoder\": ['categorical_covariate']\n",
       "\"time_varying_reals_decoder\":        ['real_covariate']\n",
       "\"time_varying_reals_encoder\":        ['real_covariate', 'value']\n",
       "\"weight_decay\":                      0.0\n",
       "\"x_categoricals\":                    ['group', 'categorical_covariate']\n",
       "\"x_reals\":                           ['real_covariate', 'value']"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# create the dataset from the pandas dataframe\n",
    "dataset_with_covariates = TimeSeriesDataSet(\n",
    "    test_data_with_covariates,\n",
    "    group_ids=[\"group\"],\n",
    "    target=\"value\",\n",
    "    time_idx=\"time_idx\",\n",
    "    min_encoder_length=5,\n",
    "    max_encoder_length=5,\n",
    "    min_prediction_length=2,\n",
    "    max_prediction_length=2,\n",
    "    time_varying_unknown_reals=[\"value\"],\n",
    "    time_varying_known_reals=[\"real_covariate\"],\n",
    "    time_varying_known_categoricals=[\"categorical_covariate\"],\n",
    "    static_categoricals=[\"group\"],\n",
    ")\n",
    "\n",
    "model = FullyConnectedModelWithCovariates.from_dataset(dataset_with_covariates, hidden_size=10, n_hidden_layers=2)\n",
    "model.summarize(\"full\")  # print model summary\n",
    "model.hparams"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check that the model works, pass a sample batch through it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Output(prediction=tensor([[0.5677, 0.5771],\n",
       "        [0.5779, 0.6040],\n",
       "        [0.5705, 0.5793],\n",
       "        [0.5616, 0.5751]], grad_fn=<ClampBackward>))"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x, y = next(iter(dataset_with_covariates.to_dataloader(batch_size=4)))  # generate batch\n",
    "model(x)  # pass batch through model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Implementing an autoregressive / recurrent model"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "Time series models are often autoregressive, i.e. instead of making ``n`` predictions for all future steps in one function call, they predict one step ahead ``n`` times. PyTorch Forecasting comes with a\n",
    ":py:class:`~pytorch_forecasting.models.base_model.AutoRegressiveBaseModel` and a :py:class:`~pytorch_forecasting.models.base_model.AutoRegressiveBaseModelWithCovariates` for such models.\n",
    "\n",
    ".. autoclass:: pytorch_forecasting.models.base_model.AutoRegressiveBaseModel\n",
    "    :noindex:\n",
    "\n",
    "In this section, we will implement a simple LSTM model that could easily be extended to work with covariates. Note that because we do not handle covariates, lagged targets (``target_lags``) cannot be incorporated into this network. We use an implementation of the :py:class:`~pytorch_forecasting.models.nn.rnn.LSTM` that can handle zero-length sequences but otherwise exactly mirrors the native PyTorch implementation."
   ]
  },
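  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference between one-shot and autoregressive prediction can be illustrated without any framework. The NumPy sketch below applies a one-step-ahead function `n_decoder_steps` times and feeds each prediction back in as the lagged target of the next step. Both `one_step` (a fixed linear rule standing in for a trained network) and `decode_autoregressive_sketch` are hypothetical helpers, not part of the PyTorch Forecasting API.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def one_step(last_target: np.ndarray, hidden: np.ndarray):\n",
    "    # hypothetical one-step-ahead 'network': a fixed linear update of a hidden state\n",
    "    hidden = 0.5 * hidden + last_target\n",
    "    prediction = 0.1 * hidden\n",
    "    return prediction, hidden\n",
    "\n",
    "\n",
    "def decode_autoregressive_sketch(first_target: np.ndarray, first_hidden: np.ndarray, n_decoder_steps: int) -> np.ndarray:\n",
    "    # predict one step ahead n_decoder_steps times, feeding each prediction back in\n",
    "    predictions = []\n",
    "    target, hidden = first_target, first_hidden\n",
    "    for _ in range(n_decoder_steps):\n",
    "        target, hidden = one_step(target, hidden)\n",
    "        predictions.append(target)\n",
    "    # stack to shape batch_size x n_decoder_steps\n",
    "    return np.stack(predictions, axis=1)\n",
    "\n",
    "\n",
    "# batch of 4 series, last observed target 1.0, empty initial hidden state\n",
    "preds = decode_autoregressive_sketch(np.ones(4), np.zeros(4), n_decoder_steps=2)\n",
    "print(preds.shape)  # (4, 2)\n",
    "```\n",
    "\n",
    "In the `LSTMModel` below, `decode_one()` plays the role of `one_step()`, while the base class method `decode_autoregressive()` runs the loop and additionally re-scales the predictions."
   ]
  },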
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  | Name            | Type       | Params\n",
      "-----------------------------------------------\n",
      "0 | loss            | SMAPE      | 0     \n",
      "1 | logging_metrics | ModuleList | 0     \n",
      "2 | lstm            | LSTM       | 1.4 K \n",
      "3 | output_layer    | Linear     | 11    \n",
      "-----------------------------------------------\n",
      "1.4 K     Trainable params\n",
      "0         Non-trainable params\n",
      "1.4 K     Total params\n",
      "0.006     Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"dropout\":                    0.1\n",
       "\"hidden_size\":                10\n",
       "\"learning_rate\":              0.001\n",
       "\"log_gradient_flow\":          False\n",
       "\"log_interval\":               -1\n",
       "\"log_val_interval\":           -1\n",
       "\"logging_metrics\":            ModuleList()\n",
       "\"loss\":                       SMAPE()\n",
       "\"monotone_constaints\":        {}\n",
       "\"n_layers\":                   2\n",
       "\"optimizer\":                  ranger\n",
       "\"optimizer_params\":           None\n",
       "\"output_transformer\":         GroupNormalizer(center=True, eps=1e-08, groups=[], method='standard',\n",
       "                scale_by_group=False, transformation=None)\n",
       "\"reduce_on_plateau_min_lr\":   1e-05\n",
       "\"reduce_on_plateau_patience\": 1000\n",
       "\"target\":                     value\n",
       "\"target_lags\":                {}\n",
       "\"weight_decay\":               0.0"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from typing import Dict\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "from pytorch_forecasting.models.base_model import AutoRegressiveBaseModel\n",
    "from pytorch_forecasting.models.nn import LSTM\n",
    "\n",
    "\n",
    "class LSTMModel(AutoRegressiveBaseModel):\n",
    "    def __init__(\n",
    "        self,\n",
    "        target: str,\n",
    "        target_lags: Dict[str, Dict[str, int]],\n",
    "        n_layers: int,\n",
    "        hidden_size: int,\n",
    "        dropout: float = 0.1,\n",
    "        **kwargs,\n",
    "    ):\n",
    "        # arguments target and target_lags are required for autoregressive models\n",
    "        # even though target_lags cannot be used without covariates\n",
    "        # saves arguments in signature to `.hparams` attribute, mandatory call - do not skip this\n",
    "        self.save_hyperparameters()\n",
    "        # pass additional arguments to BaseModel.__init__, mandatory call - do not skip this\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "        # use version of LSTM that can handle zero-length sequences\n",
    "        self.lstm = LSTM(\n",
    "            hidden_size=self.hparams.hidden_size,\n",
    "            input_size=1,\n",
    "            num_layers=self.hparams.n_layers,\n",
    "            dropout=self.hparams.dropout,\n",
    "            batch_first=True,\n",
    "        )\n",
    "        self.output_layer = nn.Linear(self.hparams.hidden_size, 1)\n",
    "\n",
    "    def encode(self, x: Dict[str, torch.Tensor]):\n",
    "        # we need at least one encoding step because the target needs to be lagged by one time step\n",
    "        # because we use the custom LSTM, we do not have to require encoder lengths of > 1\n",
    "        # but can handle lengths of >= 1\n",
    "        assert x[\"encoder_lengths\"].min() >= 1\n",
    "        input_vector = x[\"encoder_cont\"].clone()\n",
    "        # lag target by one\n",
    "        input_vector[..., self.target_positions] = torch.roll(\n",
    "            input_vector[..., self.target_positions], shifts=1, dims=1\n",
    "        )\n",
    "        input_vector = input_vector[:, 1:]  # first time step cannot be used because of lagging\n",
    "\n",
    "        # determine effective encoder length (the first time step is lost to lagging)\n",
    "        effective_encoder_lengths = x[\"encoder_lengths\"] - 1\n",
    "        # run through LSTM network\n",
    "        _, hidden_state = self.lstm(\n",
    "            input_vector, lengths=effective_encoder_lengths, enforce_sorted=False  # passing the lengths directly\n",
    "        )  # first output (LSTM outputs) is not needed, only the hidden state\n",
    "        return hidden_state\n",
    "\n",
    "    def decode(self, x: Dict[str, torch.Tensor], hidden_state):\n",
    "        # again lag target by one\n",
    "        input_vector = x[\"decoder_cont\"].clone()\n",
    "        input_vector[..., self.target_positions] = torch.roll(\n",
    "            input_vector[..., self.target_positions], shifts=1, dims=1\n",
    "        )\n",
    "        # but this time fill in missing target from encoder_cont at the first time step instead of throwing it away\n",
    "        last_encoder_target = x[\"encoder_cont\"][\n",
    "            torch.arange(x[\"encoder_cont\"].size(0), device=x[\"encoder_cont\"].device),\n",
    "            x[\"encoder_lengths\"] - 1,\n",
    "            self.target_positions.unsqueeze(-1),\n",
    "        ].T\n",
    "        input_vector[:, 0, self.target_positions] = last_encoder_target\n",
    "\n",
    "        if self.training:  # training mode\n",
    "            lstm_output, _ = self.lstm(input_vector, hidden_state, lengths=x[\"decoder_lengths\"], enforce_sorted=False)\n",
    "\n",
    "            # transform into right shape\n",
    "            prediction = self.output_layer(lstm_output)\n",
    "            prediction = self.transform_output(prediction, target_scale=x[\"target_scale\"])\n",
    "\n",
    "            # predictions have been rescaled (de-normalized) by transform_output\n",
    "            return prediction\n",
    "\n",
    "        else:  # prediction mode\n",
    "            target_pos = self.target_positions\n",
    "\n",
    "            def decode_one(idx, lagged_targets, hidden_state):\n",
    "                x = input_vector[:, [idx]]\n",
    "                # overwrite at target positions\n",
    "                x[:, 0, target_pos] = lagged_targets[-1]  # take most recent target (i.e. lag=1)\n",
    "                lstm_output, hidden_state = self.lstm(x, hidden_state)\n",
    "                # transform into right shape\n",
    "                prediction = self.output_layer(lstm_output)[:, 0]  # take first timestep\n",
    "                return prediction, hidden_state\n",
    "\n",
    "            # make predictions which are fed into next step\n",
    "            output = self.decode_autoregressive(\n",
    "                decode_one,\n",
    "                first_target=input_vector[:, 0, target_pos],\n",
    "                first_hidden_state=hidden_state,\n",
    "                target_scale=x[\"target_scale\"],\n",
    "                n_decoder_steps=input_vector.size(1),\n",
    "            )\n",
    "\n",
    "            # predictions are already rescaled\n",
    "            return output\n",
    "\n",
    "    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n",
    "        hidden_state = self.encode(x)  # encode to hidden state\n",
    "        output = self.decode(x, hidden_state)  # decode leveraging hidden state\n",
    "\n",
    "        return self.to_network_output(prediction=output)\n",
    "\n",
    "\n",
    "model = LSTMModel.from_dataset(dataset, n_layers=2, hidden_size=10)\n",
    "model.summarize(\"full\")\n",
    "model.hparams"
   ]
  },
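  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The lagging logic in `encode()` and `decode()` is easiest to follow on a toy batch. The NumPy sketch below uses made-up values, with `np.roll` standing in for `torch.roll`: it shifts the target by one step, drops the first encoder step and gathers the last observed target of each series to seed the first decoder step. The variable names mirror the model above, but the batch itself is hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# hypothetical batch: 2 series, max encoder length 4, one continuous feature (the target)\n",
    "encoder_cont = np.array(\n",
    "    [\n",
    "        [[1.0], [2.0], [3.0], [4.0]],\n",
    "        [[5.0], [6.0], [7.0], [0.0]],  # only 3 valid steps, the last step is padding\n",
    "    ]\n",
    ")\n",
    "encoder_lengths = np.array([4, 3])\n",
    "target_positions = np.array([0])\n",
    "\n",
    "# lag the target by one time step, as in encode()\n",
    "lagged = encoder_cont.copy()\n",
    "lagged[..., target_positions] = np.roll(lagged[..., target_positions], shift=1, axis=1)\n",
    "lagged = lagged[:, 1:]  # the first step holds a wrapped-around value and is dropped\n",
    "effective_encoder_lengths = encoder_lengths - 1\n",
    "\n",
    "# last observed target per series, as in decode(): index each series at its own length - 1\n",
    "batch_index = np.arange(encoder_cont.shape[0])\n",
    "last_encoder_target = encoder_cont[batch_index, encoder_lengths - 1][:, target_positions]\n",
    "```\n",
    "\n",
    "The lagged rows are `[1, 2, 3]` and `[5, 6, 7]`. For the second series only the first 2 entries lie within its effective length, so the value 7 (which is paired with the padded step) is ignored once the lengths are passed to the LSTM. The gathered last targets are 4 and 7."
   ]
  },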
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "We used the :py:meth:`~pytorch_forecasting.models.base_model.BaseModel.transform_output` method to apply the inverse transformation. It is also used under the hood for re-scaling/de-normalizing predictions and relies on the ``output_transformer`` to do so. The ``output_transformer`` is the ``target_normalizer`` used in the dataset; when the model is initialized from the dataset, it is copied to the model automatically.\n",
    "\n",
    "We can now check that training mode and autoregressive inference mode deliver predictions of the same shape:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "prediction shape in training: torch.Size([4, 2, 1])\n",
      "prediction shape in inference: torch.Size([4, 2, 1])\n"
     ]
    }
   ],
   "source": [
    "x, y = next(iter(dataloader))\n",
    "\n",
    "print(\n",
    "    \"prediction shape in training:\", model(x)[\"prediction\"].size()\n",
    ")  # batch_size x decoder time steps x 1 (1 for one target dimension)\n",
    "model.eval()  # set model into eval mode to use autoregressive prediction\n",
    "print(\"prediction shape in inference:\", model(x)[\"prediction\"].size())  # should be the same as in training"
   ]
  },
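  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the role of `transform_output()` more concrete, the NumPy sketch below inverts a standard normalization. It assumes, as a simplification, that `target_scale` holds a `(center, scale)` pair per series and that the normalizer computed `(y - center) / scale`; the optional clamp at zero only resembles a `relu` transformation. `denormalize_sketch` is a hypothetical helper, not the library's implementation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def denormalize_sketch(prediction: np.ndarray, target_scale: np.ndarray, relu: bool = False) -> np.ndarray:\n",
    "    # prediction: batch_size x n_decoder_steps x 1, in normalized space\n",
    "    # target_scale: batch_size x 2, assumed to hold (center, scale) per series\n",
    "    center = target_scale[:, 0].reshape(-1, 1, 1)\n",
    "    scale = target_scale[:, 1].reshape(-1, 1, 1)\n",
    "    rescaled = prediction * scale + center  # invert (y - center) / scale\n",
    "    if relu:\n",
    "        rescaled = np.maximum(rescaled, 0.0)  # optional clamp at zero\n",
    "    return rescaled\n",
    "\n",
    "\n",
    "normalized = np.array([[[0.0], [1.0]], [[-1.0], [2.0]]])\n",
    "target_scale = np.array([[10.0, 2.0], [0.5, 1.0]])\n",
    "result = denormalize_sketch(normalized, target_scale, relu=True)\n",
    "```\n",
    "\n",
    "The first series is mapped to 10 and 12, the second to 0 (clamped from -0.5) and 2.5."
   ]
  },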
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using and defining a custom/non-trivial metric"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To use a different metric, simply pass it to the model when initializing it (preferably via the `from_dataset()` method). For example, to use the mean absolute error with our `FullyConnectedModel` from the beginning of this tutorial, run:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"hidden_size\":                10\n",
       "\"input_size\":                 5\n",
       "\"learning_rate\":              0.001\n",
       "\"log_gradient_flow\":          False\n",
       "\"log_interval\":               -1\n",
       "\"log_val_interval\":           -1\n",
       "\"logging_metrics\":            ModuleList()\n",
       "\"loss\":                       MAE()\n",
       "\"monotone_constaints\":        {}\n",
       "\"n_hidden_layers\":            2\n",
       "\"optimizer\":                  ranger\n",
       "\"optimizer_params\":           None\n",
       "\"output_size\":                2\n",
       "\"output_transformer\":         GroupNormalizer(center=True, eps=1e-08, groups=[], method='standard',\n",
       "                scale_by_group=False, transformation=None)\n",
       "\"reduce_on_plateau_min_lr\":   1e-05\n",
       "\"reduce_on_plateau_patience\": 1000\n",
       "\"weight_decay\":               0.0"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pytorch_forecasting.metrics import MAE\n",
    "\n",
    "model = FullyConnectedModel.from_dataset(dataset, hidden_size=10, n_hidden_layers=2, loss=MAE())\n",
    "model.hparams"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that some metrics might require a certain form of model prediction, e.g. quantile prediction assumes an output of shape `batch_size x n_decoder_timesteps x n_quantiles` instead of `batch_size x n_decoder_timesteps`. For the `FullyConnectedModel`, this means that we need to use a modified version of the `FullyConnectedModule` network. Here, `n_outputs` corresponds to the number of quantiles."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([20, 2, 7])"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "class FullyConnectedMultiOutputModule(nn.Module):\n",
    "    def __init__(self, input_size: int, output_size: int, hidden_size: int, n_hidden_layers: int, n_outputs: int):\n",
    "        super().__init__()\n",
    "\n",
    "        # input layer\n",
    "        module_list = [nn.Linear(input_size, hidden_size), nn.ReLU()]\n",
    "        # hidden layers\n",
    "        for _ in range(n_hidden_layers):\n",
    "            module_list.extend([nn.Linear(hidden_size, hidden_size), nn.ReLU()])\n",
    "        # output layer\n",
    "        self.n_outputs = n_outputs\n",
    "        module_list.append(\n",
    "            nn.Linear(hidden_size, output_size * n_outputs)\n",
    "        )  # <<<<<<<< modified: replaced output_size with output_size * n_outputs\n",
    "\n",
    "        self.sequential = nn.Sequential(*module_list)\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        # x of shape: batch_size x n_timesteps_in\n",
    "        # output of shape batch_size x n_timesteps_out x n_outputs\n",
    "        return self.sequential(x).reshape(x.size(0), -1, self.n_outputs)  # <<<<<<<< modified: added reshape\n",
    "\n",
    "\n",
    "# test that network works as intended\n",
    "network = FullyConnectedMultiOutputModule(input_size=5, output_size=2, hidden_size=10, n_hidden_layers=2, n_outputs=7)\n",
    "network(torch.rand(20, 5)).shape  # <<<<<<<<<< instead of shape (20, 2), returning additional dimension for quantiles"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "Using the above-defined ``FullyConnectedMultiOutputModule``, we could create a new model and use :py:class:`~pytorch_forecasting.metrics.QuantileLoss`. Note that you would have to align ``n_outputs`` with the number of quantiles in the :py:class:`~pytorch_forecasting.metrics.QuantileLoss` class, either manually or by making use of the ``from_dataset()`` method. If you want to switch back to a loss on a single output, such as :py:class:`~pytorch_forecasting.metrics.MAE`, simply set ``n_outputs=1``, as all PyTorch Forecasting metrics can handle the additional third dimension as long as it is of size 1."
   ]
  },
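  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shape requirement can be illustrated with the pinball loss on which quantile losses are based. The NumPy sketch below expects predictions of shape `batch_size x n_decoder_timesteps x n_quantiles`, broadcasts the target over the last dimension and checks that the number of outputs matches the number of quantiles. `quantile_loss_sketch` is a hypothetical helper, not the implementation of `QuantileLoss`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "quantiles = [0.1, 0.5, 0.9]  # n_outputs of the network must equal len(quantiles)\n",
    "\n",
    "\n",
    "def quantile_loss_sketch(y_pred: np.ndarray, target: np.ndarray) -> np.ndarray:\n",
    "    # y_pred: batch_size x n_decoder_timesteps x n_quantiles\n",
    "    # target: batch_size x n_decoder_timesteps\n",
    "    assert y_pred.shape[-1] == len(quantiles), 'align n_outputs with the number of quantiles'\n",
    "    errors = target[..., None] - y_pred  # broadcast target over the quantile dimension\n",
    "    q = np.asarray(quantiles)\n",
    "    # pinball loss: under-prediction is weighted by q, over-prediction by (1 - q)\n",
    "    return np.maximum(q * errors, (q - 1) * errors)\n",
    "\n",
    "\n",
    "y_pred = np.zeros((2, 2, 3))\n",
    "target = np.ones((2, 2))\n",
    "losses = quantile_loss_sketch(y_pred, target)\n",
    "print(losses.shape)  # (2, 2, 3)\n",
    "```\n",
    "\n",
    "Since every prediction undershoots the target by 1, each time step's losses equal the quantiles themselves, i.e. 0.1, 0.5 and 0.9."
   ]
  },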
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Implement a new metric"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "To implement a new metric, you simply need to inherit from :py:class:`~pytorch_forecasting.metrics.MultiHorizonMetric` and define the ``loss()`` method. :py:class:`~pytorch_forecasting.metrics.MultiHorizonMetric` handles everything from weighting to masking values for you. For example, the mean absolute error is implemented as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_forecasting.metrics import MultiHorizonMetric\n",
    "\n",
    "\n",
    "class MAE(MultiHorizonMetric):\n",
    "    def loss(self, y_pred, target):\n",
    "        loss = (self.to_prediction(y_pred) - target).abs()\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "You might notice the :py:meth:`~pytorch_forecasting.metrics.Metric.to_prediction` method. Generally speaking, it converts ``y_pred`` to a point prediction. By default, this means that it removes the third dimension from ``y_pred`` if there is one. For most metrics, this is exactly what you need.\n",
    "\n",
    "For custom :py:class:`~pytorch_forecasting.metrics.DistributionLoss` metrics, different methods need to be implemented.\n",
    "\n",
    ".. autoclass:: pytorch_forecasting.metrics.DistributionLoss\n",
    "   :members: map_x_to_distribution, rescale_parameters\n",
    "   :noindex:"
   ]
  },
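  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What `to_prediction()` and the masking in `MultiHorizonMetric` do for a simple loss such as `MAE` can be sketched in NumPy: a trailing output dimension of size 1 is dropped, the element-wise loss is computed, and time steps beyond each series' decoder length are excluded before averaging. Both helpers are hypothetical and simplified; in particular, the actual reduction and weighting in PyTorch Forecasting may differ.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def to_prediction_sketch(y_pred: np.ndarray) -> np.ndarray:\n",
    "    # drop a trailing output dimension of size 1 to obtain a point prediction\n",
    "    if y_pred.ndim == 3:\n",
    "        assert y_pred.shape[-1] == 1, 'expected a single output per time step'\n",
    "        y_pred = y_pred[..., 0]\n",
    "    return y_pred\n",
    "\n",
    "\n",
    "def masked_mae_sketch(y_pred: np.ndarray, target: np.ndarray, lengths: np.ndarray) -> float:\n",
    "    losses = np.abs(to_prediction_sketch(y_pred) - target)  # element-wise loss, as in MAE.loss()\n",
    "    # mask out time steps beyond each series' decoder length\n",
    "    mask = np.arange(target.shape[1])[None, :] < lengths[:, None]\n",
    "    return float(losses[mask].mean())\n",
    "\n",
    "\n",
    "y_pred = np.array([[[1.0], [2.0]], [[3.0], [9.0]]])\n",
    "target = np.array([[2.0, 2.0], [1.0, 0.0]])\n",
    "print(masked_mae_sketch(y_pred, target, lengths=np.array([2, 1])))  # 1.0\n",
    "```\n",
    "\n",
    "The large error of 9 at the second step of the second series does not count, because that series only has one valid decoder step."
   ]
  },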
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model output cannot be readily converted to prediction"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "Sometimes a network's ``forward()`` output does not trivially map to a prediction. This is the case, for example, if you predict the parameters of a distribution, as is done for all metrics deriving from :py:class:`~pytorch_forecasting.metrics.DistributionLoss`. In particular, this means that training and prediction need to be handled differently. Converting the parameters to predictions is typically implemented by the metric's ``to_prediction()`` method.\n",
    "\n",
    "We will now study the case of the :py:class:`~pytorch_forecasting.metrics.NormalDistributionLoss`. It requires us to predict the ``mean`` and the ``scale`` of the normal distribution. We can do so by leveraging the ``FullyConnectedMultiOutputModule`` class that we used for predicting multiple quantiles."
   ]
  },
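  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before turning to the model, the NumPy sketch below shows how two network outputs per time step can be mapped to the `mean` and `scale` of a normal distribution, how the negative log-likelihood used for training is computed, and how a point prediction is obtained. Applying softplus to keep the scale positive is an assumption of this sketch, and all helper names are hypothetical rather than the internals of `NormalDistributionLoss`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def softplus(x: np.ndarray) -> np.ndarray:\n",
    "    # numerically stable log(1 + exp(x)), always positive\n",
    "    return np.logaddexp(0.0, x)\n",
    "\n",
    "\n",
    "def normal_params_sketch(network_output: np.ndarray):\n",
    "    # network_output: batch_size x n_decoder_timesteps x 2 (n_outputs=2)\n",
    "    loc = network_output[..., 0]\n",
    "    scale = softplus(network_output[..., 1])  # map the unconstrained output to a positive scale\n",
    "    return loc, scale\n",
    "\n",
    "\n",
    "def normal_nll_sketch(loc: np.ndarray, scale: np.ndarray, target: np.ndarray) -> np.ndarray:\n",
    "    # negative log-likelihood of the target under Normal(loc, scale), minimized during training\n",
    "    return 0.5 * np.log(2 * np.pi * scale**2) + (target - loc) ** 2 / (2 * scale**2)\n",
    "\n",
    "\n",
    "output = np.zeros((3, 2, 2))\n",
    "loc, scale = normal_params_sketch(output)\n",
    "point_prediction = loc  # for prediction, the mean serves as the point forecast\n",
    "nll = normal_nll_sketch(loc, scale, target=np.zeros((3, 2)))\n",
    "```\n",
    "\n",
    "Training and prediction thus use the same network output differently: the loss needs both parameters, whereas the point prediction only needs the mean."
   ]
  },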
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "   | Name                 | Type                            | Params\n",
      "--------------------------------------------------------------------------\n",
      "0  | loss                 | NormalDistributionLoss          | 0     \n",
      "1  | logging_metrics      | ModuleList                      | 0     \n",
      "2  | network              | FullyConnectedMultiOutputModule | 324   \n",
      "3  | network.sequential   | Sequential                      | 324   \n",
      "4  | network.sequential.0 | Linear                          | 60    \n",
      "5  | network.sequential.1 | ReLU                            | 0     \n",
      "6  | network.sequential.2 | Linear                          | 110   \n",
      "7  | network.sequential.3 | ReLU                            | 0     \n",
      "8  | network.sequential.4 | Linear                          | 110   \n",
      "9  | network.sequential.5 | ReLU                            | 0     \n",
      "10 | network.sequential.6 | Linear                          | 44    \n",
      "--------------------------------------------------------------------------\n",
      "324       Trainable params\n",
      "0         Non-trainable params\n",
      "324       Total params\n",
      "0.001     Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"hidden_size\":                10\n",
       "\"input_size\":                 5\n",
       "\"learning_rate\":              0.001\n",
       "\"log_gradient_flow\":          False\n",
       "\"log_interval\":               -1\n",
       "\"log_val_interval\":           -1\n",
       "\"logging_metrics\":            ModuleList()\n",
       "\"loss\":                       SMAPE()\n",
       "\"monotone_constaints\":        {}\n",
       "\"n_hidden_layers\":            2\n",
       "\"optimizer\":                  ranger\n",
       "\"optimizer_params\":           None\n",
       "\"output_size\":                2\n",
       "\"output_transformer\":         GroupNormalizer(center=True, eps=1e-08, groups=[], method='standard',\n",
       "                scale_by_group=False, transformation=None)\n",
       "\"reduce_on_plateau_min_lr\":   1e-05\n",
       "\"reduce_on_plateau_patience\": 1000\n",
       "\"weight_decay\":               0.0"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from copy import copy\n",
    "\n",
    "from pytorch_forecasting.metrics import NormalDistributionLoss\n",
    "\n",
    "\n",
    "class FullyConnectedForDistributionLossModel(BaseModel):  # we inherit the `from_dataset` method\n",
    "    def __init__(self, input_size: int, output_size: int, hidden_size: int, n_hidden_layers: int, **kwargs):\n",
    "        # saves arguments in signature to `.hparams` attribute, mandatory call - do not skip this\n",
    "        self.save_hyperparameters()\n",
    "        # pass additional arguments to BaseModel.__init__, mandatory call - do not skip this\n",
    "        super().__init__(**kwargs)\n",
    "        self.network = FullyConnectedMultiOutputModule(\n",
    "            input_size=self.hparams.input_size,\n",
    "            output_size=self.hparams.output_size,\n",
    "            hidden_size=self.hparams.hidden_size,\n",
    "            n_hidden_layers=self.hparams.n_hidden_layers,\n",
    "            n_outputs=2,  # <<<<<<<< we predict two outputs for mean and scale of the normal distribution\n",
    "        )\n",
    "        self.loss = NormalDistributionLoss()\n",
    "\n",
    "    @classmethod\n",
    "    def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs):\n",
    "        new_kwargs = {\n",
    "            \"output_size\": dataset.max_prediction_length,\n",
    "            \"input_size\": dataset.max_encoder_length,\n",
    "        }\n",
    "        new_kwargs.update(kwargs)  # use to pass real hyperparameters and override defaults set by dataset\n",
    "        # example for dataset validation\n",
    "        assert dataset.max_prediction_length == dataset.min_prediction_length, \"Decoder only supports a fixed length\"\n",
    "        assert dataset.min_encoder_length == dataset.max_encoder_length, \"Encoder only supports a fixed length\"\n",
    "        assert (\n",
    "            len(dataset.time_varying_known_categoricals) == 0\n",
    "            and len(dataset.time_varying_known_reals) == 0\n",
    "            and len(dataset.time_varying_unknown_categoricals) == 0\n",
    "            and len(dataset.static_categoricals) == 0\n",
    "            and len(dataset.static_reals) == 0\n",
    "            and len(dataset.time_varying_unknown_reals) == 1\n",
    "            and dataset.time_varying_unknown_reals[0] == dataset.target\n",
    "        ), \"Only covariate should be the target in 'time_varying_unknown_reals'\"\n",
    "\n",
    "        return super().from_dataset(dataset, **new_kwargs)\n",
    "\n",
    "    def forward(self, x: Dict[str, torch.Tensor], n_samples: int = None) -> Dict[str, torch.Tensor]:\n",
    "        # x is a batch generated based on the TimeSeriesDataset\n",
    "        network_input = x[\"encoder_cont\"].squeeze(-1)\n",
    "        prediction = self.network(network_input)  # shape batch_size x n_decoder_steps x 2\n",
    "        # we need to scale the parameters to real space\n",
    "        prediction = self.transform_output(\n",
    "            prediction=prediction,\n",
    "            target_scale=x[\"target_scale\"],\n",
    "        )\n",
    "        if n_samples is not None:\n",
    "            # sample from distribution\n",
    "            prediction = self.loss.sample(prediction, n_samples)\n",
    "        # The conversion to a named tuple can be directly achieved with the `to_network_output` function.\n",
    "        return self.to_network_output(prediction=prediction)\n",
    "\n",
    "\n",
    "model = FullyConnectedForDistributionLossModel.from_dataset(dataset, hidden_size=10, n_hidden_layers=2)\n",
    "model.summarize(\"full\")\n",
    "model.hparams"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "As you can see, not much changes. All the magic is implemented in the metric itself, which knows how to re-scale the network output to distribution \"parameters\" and how to transform these \"parameters\" into \"predictions\". Under the hood, it uses the model's ``transform_output()`` method and the metric's ``to_prediction()`` method, respectively.\n",
    "\n",
    "We can now test that the network works as expected:"
   ]
  },
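  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the re-scaling step more concrete, the following library-independent sketch shows what happens to the two parameters of a normal distribution. It assumes that the target was normalized with a center and a scale (the kind of information stored in `target_scale`) and that a softplus keeps the predicted standard deviation positive. The helper names are hypothetical, and the actual implementation in PyTorch Forecasting differs in its details.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def softplus(x: float) -> float:\n",
    "    # numerically stable log(1 + exp(x)); maps real inputs to positive outputs (up to float underflow)\n",
    "    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))\n",
    "\n",
    "\n",
    "def rescale_normal_parameters(raw_loc: float, raw_scale: float, center: float, scale: float):\n",
    "    # hypothetical helper: the network works in normalized space, i.e. x_normalized = (x - center) / scale,\n",
    "    # so the affine normalization is undone on the distribution parameters\n",
    "    loc = raw_loc * scale + center\n",
    "    std = softplus(raw_scale) * scale\n",
    "    return loc, std\n",
    "\n",
    "\n",
    "def to_prediction(loc: float, std: float) -> float:\n",
    "    # hypothetical helper: the point prediction of a normal distribution is its mean\n",
    "    return loc\n",
    "\n",
    "\n",
    "loc, std = rescale_normal_parameters(raw_loc=0.5, raw_scale=0.0, center=10.0, scale=2.0)\n",
    "print(loc, std, to_prediction(loc, std))\n",
    "```\n",
    "\n",
    "An affine transformation with a positive scale maps a normal distribution to another normal distribution. Re-scaling the two parameters is therefore equivalent to re-scaling every sample drawn from the distribution."
   ]
  },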
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2, 2, 2, 2])"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x[\"decoder_lengths\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "parameter prediction shape:  torch.Size([4, 2, 2])\n",
      "sample prediction shape:  torch.Size([4, 2, 200])\n"
     ]
    }
   ],
   "source": [
    "x, y = next(iter(dataloader))\n",
    "\n",
    "print(\"parameter prediction shape: \", model(x)[\"prediction\"].size())\n",
    "model.eval()  # set model into eval mode for sampling\n",
    "print(\"sample prediction shape: \", model(x, n_samples=200)[\"prediction\"].size())"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "To run inference, you can still use the :py:meth:`~pytorch_forecasting.models.base_model.BaseModel.predict()` method, because additional arguments are passed to the metric's ``to_quantiles()`` method via the ``mode_kwargs`` parameter. For example, the following line generates 100 traces and subsequently calculates quantiles."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([12, 2, 7])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict(dataloader, mode=\"quantiles\", mode_kwargs=dict(n_samples=100)).shape"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext",
    "tags": []
   },
   "source": [
    "The returned quantiles are determined by the quantiles defined in the loss function and can be modified by passing a list of quantiles to the loss at initialization.\n",
    "\n",
    "Note that the sampling in the network's ``forward()`` method is not strictly necessary here. However, for stochastic, autoregressive networks such as :py:class:`~pytorch_forecasting.models.deepar.DeepAR`, predictions should be generated by passing ``n_samples=100`` directly to the ``predict()`` method. The samples should then either be aggregated with ``mode_kwargs=dict(use_metric=False)`` (added automatically) or extracted directly with ``mode=(\"raw\", \"prediction\")`` (equivalent to ``mode=\"samples\"`` in DeepAR)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.loss.quantiles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.2, 0.8]"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "NormalDistributionLoss(quantiles=[0.2, 0.8]).quantiles"
   ]
  },
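  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The quantile computation behind `mode=\"quantiles\"` with `mode_kwargs=dict(n_samples=100)` can be illustrated without any library. The sketch draws samples for every entry of the batch and every prediction step, then computes empirical quantiles from them. It is illustrative only: it assumes normally distributed predictions and interpolates linearly between sorted samples. Its function names are hypothetical rather than part of PyTorch Forecasting.\n",
    "\n",
    "```python\n",
    "import random\n",
    "from typing import List\n",
    "\n",
    "QUANTILES = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]\n",
    "\n",
    "\n",
    "def empirical_quantile(values: List[float], q: float) -> float:\n",
    "    # linear interpolation between the sorted samples\n",
    "    ordered = sorted(values)\n",
    "    position = q * (len(ordered) - 1)\n",
    "    lower = int(position)\n",
    "    upper = min(lower + 1, len(ordered) - 1)\n",
    "    weight = position - lower\n",
    "    return ordered[lower] + (ordered[upper] - ordered[lower]) * weight\n",
    "\n",
    "\n",
    "def sample_quantiles(loc, std, n_samples=100, quantiles=QUANTILES, seed=0):\n",
    "    # loc and std are nested lists of shape batch_size x n_decoder_steps\n",
    "    rng = random.Random(seed)\n",
    "    result = []\n",
    "    for loc_row, std_row in zip(loc, std):\n",
    "        row = []\n",
    "        for mu, sigma in zip(loc_row, std_row):\n",
    "            samples = [rng.gauss(mu, sigma) for _ in range(n_samples)]\n",
    "            row.append([empirical_quantile(samples, q) for q in quantiles])\n",
    "        result.append(row)\n",
    "    return result  # shape batch_size x n_decoder_steps x n_quantiles\n",
    "\n",
    "\n",
    "# 12 series with 2 prediction steps each, mirroring the shape of the predict() output above\n",
    "out = sample_quantiles(loc=[[0.0, 1.0]] * 12, std=[[1.0, 2.0]] * 12)\n",
    "print(len(out), len(out[0]), len(out[0][0]))  # prints: 12 2 7\n",
    "```\n",
    "\n",
    "The estimates become more stable as `n_samples` grows. For a plain normal distribution, the quantiles are also available in closed form. Sampling mainly pays off for models such as autoregressive networks, whose predictive distribution is only accessible through samples."
   ]
  },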
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Adding custom plotting and interpretation"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "PyTorch Forecasting supports plotting of predictions and interpretations. The figures can also be logged to TensorBoard to monitor training progress. Sometimes, the output of the network cannot be plotted directly together with the actually observed time series. In these cases (such as our ``FullyConnectedForDistributionLossModel`` from the previous section), we need to adapt the plotting function. Further, we sometimes want to visualize certain properties of the network every few batches or after every epoch. It is easy to make this happen with PyTorch Forecasting and the `LightningModule <https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html>`_ on which the :py:class:`~pytorch_forecasting.models.base_model.BaseModel` is based."
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "The :py:meth:`~pytorch_forecasting.models.base_model.BaseModel.log_interval` property provides a log interval that switches automatically between the hyperparameters ``log_interval`` and ``log_val_interval``, depending on whether the model is in training or validation mode. If it is larger than 0, logging is enabled, and logging is triggered for every batch with ``batch_idx % log_interval == 0``. You can even set it to a number smaller than 1, which leads to multiple logging events during a single batch."
   ]
  },
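  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The switching and triggering logic described above can be mirrored in a few lines of plain Python. `LoggingSketch` is a hypothetical stand-in, not the actual `BaseModel` implementation. It only models which batches trigger logging in training and in validation mode and ignores the case of intervals smaller than 1.\n",
    "\n",
    "```python\n",
    "from dataclasses import dataclass\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class LoggingSketch:\n",
    "    # hypothetical stand-in for the relevant hyperparameters of a model\n",
    "    log_interval: float = -1\n",
    "    log_val_interval: float = -1\n",
    "    training: bool = True\n",
    "\n",
    "    @property\n",
    "    def current_log_interval(self) -> float:\n",
    "        # pick the interval that matches the current mode\n",
    "        return self.log_interval if self.training else self.log_val_interval\n",
    "\n",
    "    def should_log(self, batch_idx: int) -> bool:\n",
    "        interval = self.current_log_interval\n",
    "        # logging is disabled for intervals <= 0\n",
    "        return interval > 0 and batch_idx % interval == 0\n",
    "\n",
    "\n",
    "sketch = LoggingSketch(log_interval=3, log_val_interval=1)\n",
    "print([i for i in range(10) if sketch.should_log(i)])  # prints: [0, 3, 6, 9]\n",
    "sketch.training = False\n",
    "print([i for i in range(4) if sketch.should_log(i)])  # prints: [0, 1, 2, 3]\n",
    "```\n",
    "\n",
    "In a real `torch.nn.Module`, the `training` flag is toggled by `model.train()` and `model.eval()`."
   ]
  },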
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Log whenever an example plot of predictions vs. actuals is created"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "One of the easiest ways to log a figure regularly is to override the :py:meth:`~pytorch_forecasting.models.base_model.BaseModel.plot_prediction` method, e.g. to add something to the generated plot.\n",
    "\n",
    "In the following example, we add a line showing the attention on a secondary axis of the logged figure:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "def plot_prediction(\n",
    "    self,\n",
    "    x: Dict[str, torch.Tensor],\n",
    "    out: Dict[str, torch.Tensor],\n",
    "    idx: int,\n",
    "    plot_attention: bool = True,\n",
    "    add_loss_to_title: bool = False,\n",
    "    show_future_observed: bool = True,\n",
    "    ax=None,\n",
    ") -> plt.Figure:\n",
    "    \"\"\"\n",
    "    Plot actuals vs prediction and attention\n",
    "\n",
    "    Args:\n",
    "        x (Dict[str, torch.Tensor]): network input\n",
    "        out (Dict[str, torch.Tensor]): network output\n",
    "        idx (int): sample index\n",
    "        plot_attention: if to plot attention on secondary axis\n",
    "        add_loss_to_title: if to add loss to title. Defaults to False.\n",
    "        show_future_observed: if to show actuals for future. Defaults to True.\n",
    "        ax: matplotlib axes to plot on\n",
    "\n",
    "    Returns:\n",
    "        plt.Figure: matplotlib figure\n",
    "    \"\"\"\n",
    "    # plot prediction as normal\n",
    "    fig = super().plot_prediction(\n",
    "        x, out, idx=idx, add_loss_to_title=add_loss_to_title, show_future_observed=show_future_observed, ax=ax\n",
    "    )\n",
    "\n",
    "    # add attention on secondary axis\n",
    "    if plot_attention:\n",
    "        interpretation = self.interpret_output(out)\n",
    "        ax = fig.axes[0]\n",
    "        ax2 = ax.twinx()\n",
    "        ax2.set_ylabel(\"Attention\")\n",
    "        encoder_length = x[\"encoder_lengths\"][idx]\n",
    "        ax2.plot(\n",
    "            torch.arange(-encoder_length, 0),\n",
    "            interpretation[\"attention\"][idx, :encoder_length].detach().cpu(),\n",
    "            alpha=0.2,\n",
    "            color=\"k\",\n",
    "        )\n",
    "    fig.tight_layout()\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "If you want to add a completely new figure, override the :py:meth:`~pytorch_forecasting.models.base_model.BaseModel.log_prediction` method."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Log at the end of an epoch"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "Logging at the end of an epoch is another common use case. You might want to calculate additional results in each step and then summarize them at the end of an epoch. Here, you can override the :py:meth:`~pytorch_forecasting.models.base_model.BaseModel.create_log` method to calculate the additional results, and the ``epoch_end()`` hook provided by PyTorch Lightning to summarize them.\n",
    "\n",
    "In the example below, we first calculate some interpretation results (but only if logging is enabled) and add them to the ``log`` object for later summarization. In the ``epoch_end()`` hook, we take the list of saved results and use the ``log_interpretation()`` method (which is defined elsewhere in the model) to log a figure to TensorBoard."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_forecasting.utils import detach\n",
    "\n",
    "\n",
    "def create_log(self, x, y, out, batch_idx, **kwargs):\n",
    "    # log standard\n",
    "    log = super().create_log(x, y, out, batch_idx, **kwargs)\n",
    "    # calculate interpretations etc. for later logging\n",
    "    if self.log_interval > 0:\n",
    "        interpretation = self.interpret_output(\n",
    "            detach(out),\n",
    "            reduction=\"sum\",\n",
    "            attention_prediction_horizon=0,  # attention only for first prediction horizon\n",
    "        )\n",
    "        log[\"interpretation\"] = interpretation\n",
    "    return log\n",
    "\n",
    "\n",
    "def epoch_end(self, outputs):\n",
    "    \"\"\"\n",
    "    Run at epoch end for training or validation\n",
    "    \"\"\"\n",
    "    if self.log_interval > 0:\n",
    "        self.log_interpretation(outputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Log at the end of training"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "A common use case is to log the final embeddings at the end of training. You can easily achieve this by leveraging the PyTorch Lightning ``on_fit_end()`` model hook. Override that method to log the embeddings.\n",
    "\n",
    "The following example assumes that ``input_embeddings`` is a dictionary-like object of embeddings that are being trained, such as the :py:class:`~pytorch_forecasting.models.nn.embeddings.MultiEmbedding` class. Further, it assumes that a hyperparameter ``embedding_labels`` exists (as automatically required and created by the :py:class:`~pytorch_forecasting.models.base_model.BaseModelWithCovariates`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "def on_fit_end(self):\n",
    "    \"\"\"\n",
    "    run at the end of training\n",
    "    \"\"\"\n",
    "    if self.log_interval > 0:\n",
    "        for name, emb in self.input_embeddings.items():\n",
    "            labels = self.hparams.embedding_labels[name]\n",
    "            self.logger.experiment.add_embedding(\n",
    "                emb.weight.data.cpu(), metadata=labels, tag=name, global_step=self.global_step\n",
    "            )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Minimal testing of models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Testing models is essential to detect problems early and iterate quickly. Some issues can only be identified after lengthy training, but many problems show up after one or two batches. PyTorch Lightning, on which PyTorch Forecasting is built, makes it easy to set up such tests."
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "Every model should be trainable with some minimal dataset. Here is how:\n",
    "\n",
    "#. Define a dataset that works with the model. If it takes long to create, you can save it to disk with the :py:meth:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet.save` method and load it with the :py:meth:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet.load` method when you want to run tests. In any case, create a reasonably small dataset.\n",
    "\n",
    "#. Initialize your model with ``log_interval=1`` to test the logging of plots, in particular the ``plot_prediction()`` method.\n",
    "\n",
    "#. Define a `PyTorch Lightning Trainer <https://pytorch-lightning.readthedocs.io/en/latest/trainer.html>`_ and initialize it with ``fast_dev_run=True``. This ensures that just a single batch, rather than full epochs, is passed through the training and validation steps.\n",
    "\n",
    "#. Train your model and check that it executes.\n",
    "\n",
    "As an example, we test the ``FullyConnectedForDistributionLossModel`` defined earlier in this tutorial:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: False, used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "Running in fast_dev_run mode: will run a full train, val and test loop using 1 batch(es).\n",
      "\n",
      "  | Name            | Type                            | Params\n",
      "--------------------------------------------------------------------\n",
      "0 | loss            | NormalDistributionLoss          | 0     \n",
      "1 | logging_metrics | ModuleList                      | 0     \n",
      "2 | network         | FullyConnectedMultiOutputModule | 324   \n",
      "--------------------------------------------------------------------\n",
      "324       Trainable params\n",
      "0         Non-trainable params\n",
      "324       Total params\n",
      "0.001     Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0:  50%|█████     | 1/2 [00:00<00:00, 13.82it/s, loss=0.583, v_num=]\n",
      "Validating: 0it [00:00, ?it/s]\u001b[A\n",
      "Epoch 0: 100%|██████████| 2/2 [00:00<00:00, 12.06it/s, loss=0.583, v_num=, train_loss_step=0.583, train_loss_epoch=0.583, val_loss=0.383]\n",
      "Epoch 0: 100%|██████████| 2/2 [00:00<00:00, 11.25it/s, loss=0.583, v_num=, train_loss_step=0.583, train_loss_epoch=0.583, val_loss=0.383]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaIAAAEkCAYAAABt4jWqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAA+wklEQVR4nO3deXxU1f3/8ddnJvsGhIQACWGTkARkFxRQEZVNEPcN0bZfq7bS1rrUulStti71p7UurVKrVQFRcUPEXRAFQQMCsu9ZCIGwZSX7+f1xb8oYE8jKneXzfDzyyMzcO/d+7kDmPefcM+eKMQallFLKKS6nC1BKKRXYNIiUUko5SoNIKaWUozSIlFJKOSrI6QKUUsrbrVy5slNQUNALQH/0A3xL1ADrqqqqrhs6dOi+2gc1iJRS6jiCgoJe6Ny5c1p8fPwhl8ulQ42bqaamRvLz89Pz8vJeAM6vfVyTXSmljq9/fHx8oYZQy7hcLhMfH1+A1bI8+rhD9SillC9xaQi1Dvt1/FH2aBAppZSP2rx5c0ifPn36OV1HXcOHD++7ZMmSiMaur0GklFLqfyorK0/4PjWIlFLKR9x///0Jffr06denT59+DzzwQCeAqqoqLrrooh4pKSnpEyZM6FVUVOQC+PWvf53Yu3fvfikpKenXX399EkBubm7Q+PHje/fv3z+tf//+aZ988kkkwC233NL1yiuv7D5q1Kg+F110Uc8BAwakZmRkhNXud/jw4X2/+uqriMLCQtell17ao3///mlpaWnps2bNag9QXFwskydP7pWSkpJ+3nnn9SorK5OmHJeOmlNKqSa4fd6ablvyihrd7dQYKZ2jSx+7ZGD2sdb56quvIubMmdNx5cqVG40xDB06NO3ss88u2rVrV9jzzz+/a9y4cSWXXnppj8ceeyz+pptu2r9w4cIOO3bsWOdyudi/f78b4IYbbuh2yy237B0/fnzx1q1bQ8aPH99nx44d6wHWrl0bsWLFik1RUVHmz3/+c6fZs2fHDhs2LDczMzN43759waeffnrpjBkzEs8666zCN998c9f+/fvdw4YNSzv//PMLn3jiifjw8PCaLVu2bFixYkX4qFGj0pty/NoiUkopH7B48eKoSZMmHY6Jialp165dzXnnnXdo0aJF0Z07d64YN25cCcD06dMPLFu2LCo2NrY6NDS05oorruj+8ssvt4+KiqoBWLp0aczvfve75NTU1PQpU6acVFxc7D506JALYMKECYejoqIMwDXXXHNo/vz5HQBeeeWVDlOmTDlk1xDz97//vUtqamr66NGj+5aXl8u2bdtCvv7666jp06cfABgxYsSRlJSU0qYcm7aIlFKqCY7XcmkrDV0pQUR+cj84OJjVq1dvnD9/fszcuXM7/Otf/+q0fPnyLcYYMjIyNtYGjqfIyMia2ts9e/asbN++fdWKFSvC33777djnn38+s7aGefPmbRs4cGD58epoCm0RKaWUDxg7dmzxwoUL2xcVFbkKCwtdCxcu7HDWWWcV7dmzJ+Szzz6LBJgzZ07syJEjiwsKClwHDx50X3755QXPPfdc9saNGyMARo8eXfjoo492qt3msmXLwhva3yWXXHLwoYce6lxUVOQePnz4EYCzzjqr8PHHH0+oqbEya+nSpeH2dotnzZoVC/Ddd9+FbdmypUldlxpESinlA0aPHl161VVXHRgyZEja0KFD06ZPn54fFxdX3atXr7IXX3yxY0pKSvqhQ4eCbrvttvzDhw+7J0yY0CclJSX99NNP7/uXv/wlG2DmzJnZq1atikxJSUnv3bt3v2eeeSa+of1dffXVhz744IPYqVOnHqx97JFHHsmtqqqS1NTU9D59+vS75557EgFuu+22fSUlJe6UlJT0hx56qPPJJ59c0pRjE70wnlJKHduaNWt2DRw4cL/TdfiLNWvWxA0cOLBH7X1tESmllHKUBpFSSilHaRAppZRylAaRUkopR2kQ
KURkl4ic48B+RUQeFZED9s/f5BhfRhCRCBH5p4jsF5ECEVnisay9iLwsIvvsn/vrPPdBEflBRKrqWSYicreIZIlIoYjMFZEYj+WhIvKivSxPRG6p8/xBIrJSRErt34M8ll0hIpvtevfZNXpuu7jOT7WIPG0vO1VEPhWRgyKSLyJvikgXj+feLyKVdZ7fy2N5DxFZZNe1yfPfWETGiEhNnede67H8MhFZZj93cT3/FmNFZJX9muwQkevrvJ5/EZHd9nEvFhGvm5hTeQ8NIuWk64ELgIHAAGAycMMx1p8JxAJp9u/feyz7OxAB9ACGA9NF5Ocey7cBfwA+qGe71wDTgVFAVyAceNpj+f1AH6A7cBbwBxGZACAiIcB7wCygA/Ay8J79OMBSYJQxph3QC+tL5H+p3bAxJqr2B0gAjgBv2os72Mfcw953EfBSndpf99yGMWaHx7LXgO+BjsDdwDwR8Ryum1vnuS97LDsIPAk8UvfFEpFg4B3geaAdcDnwhIgMtFe5FPgFcDrWv9M3wKt1t6NULQ0i1SC7JfCkiOTaP0+KSKi9LE5EFojIYfsT+1ci4rKX3WF/Gi6yWwNnN7CLa4HHjTE5xpjdwOPAzxqopS/WFR2vN8bkG2OqjTErPVaZAvzNGFNqjNkF/AfrzRAAY8zLxpgPsd7M65oC/McYk22MKQYeBS4Xkdov5V0DPGiMOWSM2Qj826POMVjh8qQxptwY8xQgwFh7v9nGGM9hv9XASQ28HpcA+4Cv7Od+aIx50xhTaIwpBZ7BCsvjEpEUYAhwnzHmiDHmLeAH4OLGPN8Y85kx5g0gt57FsUAM8KqxfAdsBGrnF+sJfG2M2WGMqcYK6SbNPaba1oIFC6LPOuuskwBmz57d7q677urc0Lr79+93P/LIIw1+36ght9xyS9d77703oTHrahCpY7kbOBUYhNVqGQ7cYy+7FcgB4rE+yd8FGDswZgCnGGOigfHArga23w9Y43F/jf1YfUYAmcCf7a65H0Sk7puq1Lndn8aRep4bCvQRkQ5YraSG6uwHrDU//kLeWs/jEJHRIlKAFYIXY7U06nMt8EqdbXk6A1hf57Ep9geB9SLyK4/H+wE7jDGewVv39e0kIntFZKeI/F1EIhvY748YY/ZitbZ+LiJuETkNq8X2tb3KXOAkEUmxW0/XAh81ZtuqZaqqqpr8nGnTphU89NBDeQ0tP3DggPs///lPp4aWtwYNInUs04AHjDH7jDH5wJ+xurAAKoEuQHdjTKUx5iv7DbQa6008XUSCjTG7jDHbG9h+FFDgcb8AiBKp9zxRElawFGAFwwzgZRFJs5d/BPxRRKJF5CSs1lBjpxn5ELjOPqfSDrjDfjzCrrG2Ns86oxs4hrrLMcZ8bXfNJQGPUU8wi0gycCZW195PiMgA4F7gdo+H38DqpowHfgncKyJXNrKuTVgfMLpgtd6GAk/Ut+8GvGbXU47VgrvbGFM7B9se+7HNWF2Nl/LjblTVDJs3bw7p2bNnv7qXfEhMTDz5tttu6zJ06NC+L774Yoe33347ZtCgQanp6elpEydO7FVQUOACmDdvXkzPnj37DR06tO+8efPa1273qaee6njNNdckA2RnZwede+65vfv27Zvet2/f9E8//TTy1ltvTcrOzg5NTU1Nv+GGG5IA/vSnPyX0798/LSUlJf33v/9919pt3XHHHZ179OjRf+TIkSlbt24Nbeyx6aSn6li6YrVCamXaj4H1hno/8ImdGzONMY8YY7aJyM32sn4i8jFwizGmvi6eYqwunloxQHEDLYIjWOH3F2NMFfCliCwCxmF1C/0W67zOVuAA1hvllfVspz4vAt2AxVh/E49jddfl2DXW1lbmcbu2pVH3GOou/x9jzG4R+QirxTCkzuJrsLqzdtZ9nh2sHwK/M8Z85bG9DR6rLRORf2B17712vLqMMXlA7afgnSJSe/7sWOfoautJBV4HLgQ+xTp/tkBEco0xHwD3AadgvaZ5wNXA
FyLSz+5i9G3v3tSNfRta9TIQdEov5YJnjzuZan2XfAAICwurWbly5eY9e/YETZkypfeSJUu2xMTE1Nx9992dH3zwwYQHHnggb8aMGT0+/fTTzf369SufPHlyr/q2f+ONNyaffvrpRffee+/2qqoqCgoK3I8//njO5MmTwzdt2rQB4O23347Ztm1b2Nq1azcaYzjnnHNO+vDDD6OioqJq3nnnndgffvhhQ2VlJYMGDUofPHhwo/69tUWkjiUXq8ulVrL9GMaYImPMrcaYXlhv2rfUngsyxswxxoy2n2uwzrnUZz1Wl1+tgfy066nW2mMVaow5aIyZZozpbIzph/V/+9tjHt3R59YYY+4zxvQwxiTZNewGdhtjDmF9wm+ozvXAgDqtuAHHOI4goHc9j19DPa0hEekOfIZ1jup4J/wNR7sY1wO9RCTaY/mxXl/P5x5Pf2CzMeZj+7XbjBViEz3287p97q/KGPNfrIEXep6oheq75ANYl20AWLx4ceT27dvDhg8fnpqampo+d+7cjllZWSGrV68OS0pKKj/55JPLXS4X06ZNO1Df9pctWxZ9++235wMEBQXRsWPH6rrrfPTRRzFLliyJSU9PT+/Xr1/69u3bwzZt2hS2aNGiqEmTJh2Ojo6uiY2NrRk3btzhxh6XtohUrWARCfO4X4X1yfoeEfkO643qXqwTz4jIZKzune1AIVaXXLV9jigRa7RYGVZLpqEPPK9gBdhCe/u38uPRap6WAFnAnSLyMNY5ozHYXVUi0hs4bP+MwxqRd2btk+1zFW67liD7WCuNMdUiEov1RrkDq6vrCawuydpp8V+xX4cMrPNhvwRqR+Qtto/9tyLynL0M4At7v9OwuqmysYL8r8DnngcmIiPt1+zNOo8n2tt51hjzXN0XRESm2q/LYawWyG+xztVhjNkiIquB+0TkHqyQGIA9WEFExtjHm43VZfgI1ui/2m27gWCs9wiX/XpVG2MqsUbi9RGRscAirNGAkzn6geM74FIRmQvkY3XxBmONXPR9jWi5tJW6vda196Ojo2vAukzD6NGjC99///0ftayXLVsWXn+Pd9MZY7j55pv33H777T+ae++BBx7o1Nx9aItI1VqIFRq1P/djDTPOwGqN/ACs4ujQ4z5Yn9SLsYbn/tMYsxjr/NAjwH6sbplO2G+O9XgeeN/e9jqsT9XP1y60T8BPA7DfAKcCk7DOdfwbuMYYs8lefai9nSLgYWCaMcbz0/+/7eO6EmsQxhGOnu+Ks4+/BKsL7EVjzEyP596HFbiZwJfAY8aYj+y6KrCGoF+DFQi/AC6wHwerFbDMfp2WYp03+SU/di3wdp2BBQDXYb3J3+fxXZ9ij+VXYL25F2GF5aN1hmBfAQwDDmH9m1xin+sDq2vwG/uYl2G9/r/1eO50+zX6F9Yw7CP2a4h9zu8XwFNYH0K+BN7CGqkIViCtAVbbr8nvgYuNMYdRLVLfJR88l48ZM6YkIyMjat26daEARUVFrrVr14YOGjSoLCcnJ2T9+vWhAHPnzo2tb/ujRo0qqu3uq6qq4uDBg6527dpVl5SU/C8rJk6cWPjqq6/G1Z572rlzZ/Du3buDxo4dW/zBBx+0Ly4ulkOHDrk+/fTT9o09Lp19WymljsMbZt/evHlzyKRJk/qMGDGiKCMjI6pnz57l8+bN25mamtovIyNjY5cuXaoA5s+fH33XXXclVVRUCMB99923e9q0aQXz5s2Luf3227vFxsZWjRgxonjjxo3hixYt2vbUU091zMjIiHzllVeysrOzg372s591z87ODnW5XDzzzDOZ55xzTsmUKVN6btq0KWLs2LEFzz//fM6DDz7Y6dVXX40DiIiIqJk9e/bOfv36ld9xxx2dX3/99bjExMTyrl27VqalpR154IEH9tY9lrqzb2sQKaXUcXhLEE2ePLnP1q1bGzrP5zP0MhBKKaW8igaRUkr5gL59+1b4Q2uoPhpESimlHOXVw7fj4uJMjx49nC5D
KRXgHn30UdavX9+9tYZAt6Xy8vKqwYMHrzn+ms6oqakRoMbzMa8Ooh49epCRkeF0GUqpALdz506io6Pp2LHjT77L423WrVtXcfy1nFFTUyP5+fntsL4u8D9eHURKKeUNkpKSyMnJIT8///grOywvLy+ouro6zuk6GlADrKuqqrrO80ENIqWUOo7g4GB69uzpdBmNkp6e/oMxZpjTdTSFDlZQSinlKA0ipZRSjtIgUkop5SgNIqWUUo7SIFJKKeUoDSKllFKO0uHbSinVWowBU+Pxuwao89iP7tf3WA1EJYA7cN6eA+dIlVKq5njh4PkYjVinTqi0FlNz/HX8iAaRUso/lBVAeTE/CZn/hYpee81baRAppfxDVQVUHnG6CtUMOlhBKaWUozSIlFJKOUqDSCmllKM0iJRSfqGovNrpElQzaRAppXzel1vyGf30aj7aUuR0KaoZNIiUUj6vZ8dIesSGceP8XO77fC9lVYH1PRxfp0GklPJ5yR0jePPaNK4b2oGXvz/MxXOy2HXIa6+YrerQIFJK+YUQt4t7zurECxcmklNYyeRXM5m/qdDpslQjaBAppfzKOb2jWHhND/rGhfDbBXu485M8yiq1q86baRAppfxOYkwwcy9P5tcjYnltbQFTZ2ey7UC502WpBrRKEInIBBHZLCLbROSPx1jvFBGpFpFLWmO/SinVkGC38IfT43n54iT2l1Qz5dVM5q0rcLosVY8WB5GIuIFngYlAOnCliKQ3sN6jwMct3adSSjXWmT0jWXhtDwZ2CeO2j/K4ZeEeSiq0q86btEaLaDiwzRizwxhTAcwFptaz3m+At4B9rbBPpZRqtISoIGZf2o2bR3bknQ2FTJmVycb8MqfLUrbWCKJEINvjfo792P+ISCJwIfBcK+xPKaWazO0Sbh4Zx+zLulFcXs3UWVnMXnMYo5eHcFxrBJHU81jdf9kngTuMMcedg0NErheRDBHJyM/Pb4XylFLqqJHJESy8pgcjuoVz96d7+c2CPTo9kMNaI4hygG4e95OA3DrrDAPmisgu4BLgnyJyQX0bM8bMNMYMM8YMi4+Pb4XylFLqx+Iig3j54iT+cHocH24pYvKrmfyQp111TmmNIPoO6CMiPUUkBLgCmO+5gjGmpzGmhzGmBzAP+LUx5t1W2PdPHKmo5rGPN7F02/622LxSyk+4RPj1iI68fkU3KqoNF83J5KVVh7SrzgEtDiJjTBUwA2s03EbgDWPMehG5UURubOn2m8rlgvfX7OGB9zdQVa0jY5RSxzYs0eqqO7NnJH/+Yh83vJdLQZl21Z1IrfI9ImPMQmNMijGmtzHmr/ZjzxljfjI4wRjzM2PMvNbYb31Cg9zcOTGVzXuLeD0j+/hPUEoFvA7hbv59QSL3jIln0Y5iJr2yi1W5etnxE8UvZ1aY0L8zw3vG8sQnWygsq3S6HKWUDxARrhsWy5tXJuMS4bK5WTz/7UFqtKuuzfllEIkI905O52BpBc9+sc3pcpRSPmRQl3AWTO/Oub2jeHhJPv/39m4OllY5XZZf88sgAuif2I6LhyTx0tJdZB4ocbocpZQPaRfm5p/nd+XBszuxNKuUia9ksiK71Omy/JbfBhHA7eP7EuQWHl64yelSlFI+RkSYPrgD70xLJiJYuPKNbJ7+5gDVNdpV19r8OogSYsL41Zm9+Wh9Hst3HHC6HKWUD+rXKYz3p/dgSmo0jy/dzzXzcthXol11rcmvgwjgl2f0omu7MB5csEE/ySilmiUqxMWTk7rwt/GdWZl7hEkv7+LrTO3yby1+H0RhwW7umJjK+txC3lqV43Q5SikfJSJcdnI75l/dnQ7hbqa/mcPjX++nSj/gtpjfBxHA+QO7Mji5PY99vJnicm1SK6WaLyUulPemdefS/u14evkBrnojmz1F+jWRlgiIIKodzp1fVM5zi7c7XY5SysdFhLj424TO/H1SZ9btLWPSK5ks2lHsdFk+KyCCCGBwcgcuGNSVmV/tIOeQ
DsNUSrXchenteH96DzpHBfHzt3fz8Jf7qKzWrrqmCpggAvjDhFRcAo9+tNnpUpRSfqJ3bAjvTEvm6oHtef67Q1w2N4vsAu2qa4qACqKu7cO5/ozevL8ml5WZB50uRynlJ8KCXPzl3ASendKVbQcqOO+VXXy0tajxG6iphuoKqDwCFcVgAmvC5oAKIoAbz+xFQkwoDyzYSI2OdlFKtaLz+kbzwTXd6dE+mBvfy+X+z3IpP1JqhUtZARw5BCX5ULwXCvdAQTYcyoSCHOt+8T4oOQA1gTWoKuCCKCIkiNvHp7Im+zDz19S9fp9SSjXA1EB1JVSVQUUplBfBkcNQetAKl6I8KMwlmb3Mmwi/6Ofmv6uLuHhuDrv25FvrlhVaz60ss1pANYHV8mlIkNMFOOGiwYm8vGwXj360ifH9OhMe4na6JKWUE2qqrYCpqbJ+e96uqQFTe7samjALd4hbuPfUYE7r4uK2JZVMfreCh0cHM6WXvtfUJ+BaRAAul/CnyensKShj5pIdTpejlGpNNVU/Pt9SVmh3ie23ur4K91hdYYez7C6xXLtLbL/VuikrhPJiqCyFqgqormpSCHk6t7ubhReGktJB+M2iSu78upKyKj0lUFdABhHA8J6xnHdyF577cjt5BXqteqV8XulB+3zL7h+fbzlyyO4SK7HCqbqiyS2clkiMEl4/L4RfDXDz2uZqLphfwbbD2iXnKWCDCOCPE1OprjH87WOdnVspn+fFF7ALdgl3nBLMf8cHs++IYcp7Fby1VS9HXiugg6hbbAT/d3pP3l61mzXZh50uRynl58YkufnwglAGxAm3Lqnk1iUVlFZ6b4CeKAEdRAC/HtObuKgQHlywAePFn6iUUv4hIVKYMzGE3w528/bWGqa8V8Gmg4HdVRfwQRQdFsxt4/qSkXmID37Y43Q5SqkA4HYJtwwJZvbEYAorDFPnV/DapqqA/TAc8EEEcOmwbqR1ieGRDzdRVqn9tkqpE2NkV2tU3fAEF3cureK3iyspqgi8MNIgwvp08qfz0sg5dIQXl+50uhylVACJDxdenhDM7cOCWLizhsnvVrAutwnTA/kBDSLbyJPiODc9gWe/2Ma+Ih3OrZQ6cVwi3DQwiLmTQqioMVz9ylpKAujaaRpEHu6alEZFdQ1PfLLF6VKUUgHolM4uFl4Qyj8vSycyNHAmvtEg8tAzLpJrT+vB6xnZrM8tcLocpVQA6hAmjOzVwekyTigNojp+c3Yf2ocH63BupZQ6QTSI6mgXHswt56awfMdBPtmw1+lylFLK72kQ1ePK4cn06RTFQws3Ul6lw7mVUqottUoQicgEEdksIttE5I/1LJ8mImvtn2UiMrA19ttWgtwu7pmcTuaBUl5Zlul0OUop5ddaHEQi4gaeBSYC6cCVIpJeZ7WdwJnGmAHAg8DMlu63rZ2ZEs+YvvE89cVWDhSXO12OUq3OGMOf31/PysyDTpeiAlxrtIiGA9uMMTuMMRXAXGCq5wrGmGXGmEP23eVAUivst83dc14apRXVPPnZVqdLUarVrco6xEtLd7Flb7HTpagA1xpBlAhke9zPsR9ryP8BHza0UESuF5EMEcnIz89vhfKa76RO0Vw9IpnZKzLZsjewvums/N+s5VlEhwYxdVBXp0tRAa41gkjqeazecc8ichZWEN3R0MaMMTONMcOMMcPi4+NbobyWufmcFKJCg3Q4t/IrB0sq+OCHPVw0JJGIkMD54qTyTq0RRDlAN4/7SUBu3ZVEZADwAjDVGHOgFfZ7QnSIDOF356Tw1db9LN7sbAtNqdYyb2U2FVU1TDu1u9OlKNUqQfQd0EdEeopICHAFMN9zBRFJBt4GphtjfG7+nOmndqdnXCR/+WADldWBfd0Q5ftqagyzV2QxvEcsKQnRTpejVMuDyBhTBcwAPgY2Am8YY9aLyI0icqO92r1AR+CfIrJaRDJaut8TKSTIxd2T0tieX8Ls5TqcW/m2pdv3k3mglGmnJjtdilIA
tErnsDFmIbCwzmPPedy+DriuNfbllLPTOjHqpI48+flWLhicSPuIEKdLUqpZZi3PpGNkCBP6d3a6FKUAnVmh0USEe85Lp/BIJf/4XIdzK9+UV1DGZxv3cemwboQGuZ0uR3mqqYbqCqg8AiawTgHocJkmSOsSw+WnJPPqN5lcfWp3esdHOV2SUk0y97ssaozhquHaLXfC1VRbAVNTZd+2f9dUg6kGz1G5NYFzLSLQFlGT3XJuCmHBbh5euNHpUpRqkqrqGuZ+m80ZfeJJ7hjhdDn+x9RAdaXVoikvhrICKD0IxfugaA8U74WSfDhyCMoLoaIUqsqt0Anwr4ZoEDVRfHQoM8aexGcb9/H11v1Ol6NUo32+aR95hWVMG6GtoWapDZqqMqgo+WnQFOXVCZoSa10NmuPSIGqGn4/qQbfYcB5csIEqHc6tfMSs5Zl0aRfG2NROTpfinYz5adAcOWiFS1He0aApPWgt06BpNRpEzRAa5OauiWls3lvE6xnZx3+CUg7LPFDCV1v3c+XwZILcAfpn/5OgKawTNHt+GjSVZdZzAmzwwImmgxWaaUL/zgzvGcsTn2xhysCuxIQFO12SUg2asyILt0u4/JRux1/ZVxlzdCBAvQMCNEy8VYB+NGo5EeFP56VzsLSCZ7/Y5nQ5SjWorLKaNzKyGZeeQEJMmNPltJ0jh6zzNaUH7BZNsbZofIQGUQucnNSOi4ck8dLSXWQeKHG6HKXq9dG6PA6VVjJthM4rp7yTBlEL3T6+L0Fu4eGFm5wuRal6zVqeSc+4SEb27uh0KUrVS4OohRJiwvjVmb35aH0ey3f4zKTiKkBsyiskI/MQ00Yk43LVd8UWpZynQdQKfnlGL7q2C+PBBRuortFhnMp7zF6eRUiQi4uH+MRFkVWA0iBqBWHBbu6YmMr63ELeWpXjdDlKAVBSXsU73+9m8oAudIjUSXqV99IgaiXnD+zK4OT2PPbxZkrKA2ueKOWd3ludS3F5lQ5SUF5Pg6iViAh/mpxOflE5/1q83elyVIAzxjBreSZpXWIYktze6XKUOiYNolY0JLkDUwd15d9f7SDnUKnT5agA9n32YTbsKeTqU5MR0UEKyrtpELWyOyakIgKPfrTZ6VJUAJu9PIvIEDdTByU6XYpSx6VB1Mq6tg/n+tN78f6aXFZmHnS6HBWADpdWsGBtLhcOSSQqVGfxUt5Pg6gN3HBmbxJiQnlgwUZqdDi3OsHmrcyhvKpGBykon6FB1AYiQ4O4fXwqa7IPM39NrtPlqABijGHOiiyGdu9AWpcYp8tRqlE0iNrIRYMTOTmxHY9+tIkjFdVOl6MCxLLtB9ixv4SrT9WL3ynfoUHURlwuazj3noIyZi7Z4XQ5KkDMXpFJh4hgJvbv4nQpSjWaBlEbGt4zlkknd+a5L7eTV1DmdDnKz+0rLOOT9Xu5dFg3woLdTpejVKNpELWxOyemUV1j+NvHOju3aluvf5dNVY3hyuHaLad8iwZRG+sWG8EvRvfk7VW7WZtz2OlylJ+qrjG89m0Wp/eJo2dcpNPlKNUkGkQnwE1n9SYuKoQH3t+AMTqcW7W+LzbtI7egTIdsK5+kQXQCRIcFc+u4vmRkHmLhD3lOl6P80OwVmSTEhHJOWienS1GqyTSITpDLhnUjtXM0D3+4kbJKHc6tWk/2wVK+3JLPFackE+TWP2nle1rlf62ITBCRzSKyTUT+WM9yEZGn7OVrRWRIa+zXl7hdwr2T08k5dIQXl+50uhzlR+Z8m4VLhCuGd3O6FKWapcVBJCJu4FlgIpAOXCki6XVWmwj0sX+uB/7V0v36opEnxXFOWgLPfrGNfUU6nFu1XHlVNW98l83ZqZ3o0i7c6XJUaygvgry1TldxQrXGjIjDgW3GmB0AIjIXmAps8FhnKvCKsc7ULxeR9iLSxRizpxX271PuPi+NcX//kic+2cIjFw9wuhzl4z5al8eBkgquPlUHKficimIo2A1Fu6Eoz/op2Wc9Hh4L
dwROz0lrBFEikO1xPwcY0Yh1EoGAC6KecZFcc1oPXly6k+mndadf13ZOl6R82OwVWXTvGMHok+KcLkU1pKIECndDYS4U7fEInKKj67iCIaoTxPWB6C6Qdj4YAwFyLanWCKL6Xqm6Y5Qbs461osj1WN13JCf75xfzfju2D2+vyuEvCzYy55cj9MJlqlm27C3i250HuXNiKi6X/h9yXOURj8DJhaK9UJz348CRIIiKh469rcCJ7gIxiRAZD+JxpiRlfMCEELROEOUAnmdJk4C6U043Zh0AjDEzgZkAw4YN88sv3bSLCOb356Zw73vr+XTDXsb16+x0ScoHzVmRRYjbxSVDk5wuJbBUHrGCptCjS614L5QXHF1H3Fa4xPaG6ASI7moFTlSnHweOAloniL4D+ohIT2A3cAVwVZ115gMz7PNHI4CCQDw/5Omq4cm8+k0mDy3cyJi+nQgJ0v+cqvFKK6p4a2UOk07uTMeoUKfL8U9V5VbrpnA3FO85Gjhlh4+uI26I6AgdukNU56OBE52ggdMELQ4iY0yViMwAPgbcwIvGmPUicqO9/DlgITAJ2AaUAj9v6X59XZDbxd3npfGzl77jlW92cd3pvZwuSfmQ+atzKSqv0kEKraGqwm7heHSpleyFI55XWHZBZEdo1w2SToEou0stKgHcehXclmqVV9AYsxArbDwfe87jtgFuao19+ZMxfTsxpm88//h8KxcOTtRPtqrRZq/Iom9CNEO7d3C6FN9RXekROB4tnCMHPFYSq4UT3RW6DvboUuusgdOG9JV12D3npTH+ya948rOtPHhBf6fLUT5gTfZhfthdwINT++lAl/pUV1mDBAp32y0cO3BKD/CjMVLhHa0WTZeBduB0tQYPuIMdKz1QaRA57KRO0Uwbkcys5ZlMP607KQnRTpekvNzsFZlEhLi5YHCi06V4j13L4PP74dAOKDkA1BxdFh4LkQmQ0N8jcLpCUIhT1ao6NIi8wM3npPDu97t5cMEGXvnFcP2UqxpUUFrJ/DW5XDg4iegw/eT+PxVFcHC71a3WKf3osOjorhCkXd7eToPIC8RGhvDbs/vwlw82snhzPmel6gzKqn5vrcqhrLKGaSP88zt2zZYyHi591Qoj5XN0fKGXuOa0HvSMi+QvH2ygsrrm+E9QAccYw+wVmQzq1p7+iTojh/IfGkReIiTIxV2T0tieX8KcFVlOl6O80PIdB9meX6JDtpXf0SDyIuekdWLUSR35+2dbKCitdLoc5WVmrcikXXgwkwd0cboUpVqVBpEXERHuOS+dwiOV/OPzrU6Xo7xIflE5H6/L45KhSYQFu50uR6lWpUHkZdK6xHD5Kcm88s0utucXO12O8hJvZGRTVWO4SgcpKD+kQeSFbjk3hbBgNw8v3Oh0KcoLVNcY5qzIYmTvjvSOj3K6HKVanQaRF4qPDuWms07is437+HrrfqfLUQ77css+dh8+ooMUlN/SIPJSPx/Vg26x4Ty4QIdzB7pZy7OIjw7l3PQEp0tRqk1oEHmpsGA3d09KY/PeIiY8uYRPN+zFmjtWBZKcQ6Us2ryPK07pRrBb/1yVf9L/2V5sQv8u/PuaYRjgl69kcPnM5azJPux0WeoEeu3bLAS4YrgOUlD+S4PIy52bnsDHN5/Bgxf0Z0d+MVOfXcqMOavIOlDqdGmqjVVU1fD6dzmMTe1EYvtwp8tRqs1oEPmAYLeL6ad2Z/HtZ/HbsSfx+cZ9nP3EYh5csIHDpRVOl6fayCcb8thfXM40HaSg/JwGkQ+JCg3ilnF9WXz7GC4anMRLS3dyxt8W8fyX2ymrrHa6PNXKZi3PJKlDOGf0iXe6FKXalAaRD0qICePRSwbw4e/OYGj3Djz84SbOfvxL3v1+NzU1OqDBH2zbV8zyHQe5akQybpdeFkT5Nw0iH9a3czQv/Xw4c64bQfuIYG5+fTXnP/s1y7bpd4983ewVmQS7hcuGdXO6FKXanAaRHxh5UhzvzxjN3y8fyKGS
Sq56YQU/f+lbtuwtcro01QxHKqp5a2UOE/p3IS5KL+qm/J8GkZ9wuYQLByfx+a1ncufEVDIyDzHhySX88a217C0sc7o81QTvr82lsKyKq3VeORUgNIj8TFiwmxvO7M2S28/iZyN78taqHMY8tpgnPtlMcXmV0+WpRpi9PJM+naIY3jPW6VKUOiE0iPxUh8gQ7p2Szue3jOHstE489cU2xjy2mFnLM6nSKYO81g85BazJKWDaiGREdJCCCgwaRH4uuWMEz1w1hHdvGkWvuEjueXcd455cwifr83TKIC80e0Um4cFuLhqa5HQpygnigqBQkMC65pQGUYAY1K09r99wKv++ZhgCXP/qSi5/fjmrdcogr1FYVsl7q3M5f2BXYsKCnS5HnQjuEAiJhPAOENUJojtDREdwBVYQBTldgDpxRIRz0xM4q288c7/L5snPtnDBs0uZPKALfxifSnLHCKdLDGjvrNrNkcpqvdyDv3IFWcHjDrZ+u4JAu18BDaKAFOR2cfWp3blgcCIzl+zg30t28PH6PKaf2oPfjD2JDpEhTpcYcIwxzFqeyYCkdpyc1M7pclRLudxW4Lg8gkdDp0HaNRfAokKDuOXcFBbfPoaLhyTx32U7OeOxRTynUwadcN/tOsTWfcVcPUJbQz6n9rxOaDRExEJUgvUTHguhUfY5Hw2hY9EgUiTEhPHIxQP46OYzGNa9A4/YUwa9832OThl0gsxankl0WBBTBnZ1uhR1LCINn9cJjYagsIA7v9MaWhREIhIrIp+KyFb7d4d61ukmIotEZKOIrBeR37Vkn6rtpCQcnTKoQ2Qwv399jU4ZdALsLy7nw3V7uHhIEuEh+ibmVVxBEBwBYe0gMh6iu0BknHU/ONxarlqspS2iPwKfG2P6AJ/b9+uqAm41xqQBpwI3iUh6C/er2tDIk+KYf9Nonrx80P+mDPrZS9+yOU+nDGoLb2bkUFltuPpUnUnBUS43BIdBaIzVwonuYrV4wttbLSC3jmRsKy0NoqnAy/btl4EL6q5gjNljjFll3y4CNgKJLdyvamMul3DB4EQ+v/VM7pqUyqrMQ0z8xxLumLeWvAKdMqi11NQY5nybyYiesZzUKdrpcgKHntfxKi1tVyYYY/aAFTgi0ulYK4tID2AwsOIY61wPXA+QnKyfEJ0WFuzm+jN6c+nQbjyzaBuvfLOL99bs5pen9+KGM3sTFapdEy2xZGs+2QeP8IfxqU6X4r9EfjqCTc/jeJXjvouIyGdA53oW3d2UHYlIFPAWcLMxprCh9YwxM4GZAMOGDdMz5V6iQ2QIf5qczrWn9eCxTzbz9BfbeO3bLH53TgpXnNKNYLeOe2mOWcuziIsKYXy/+v7EVLO4gz0CJ1i71HzAcYPIGHNOQ8tEZK+IdLFbQ12AfQ2sF4wVQrONMW83u1rluOSOETx95WD+b3RPHlq4kT+9u46Xlu7kjgmpjEtP0PnRmiD38BG+2LSXG8/sTUiQBnmLhcVY53X0/6DPaen//vnAtfbta4H36q4g1jvTf4CNxpgnWrg/5SUGdWvP69cfnTLoBnvKoO+zDjldms+Y+20WBrhyuHZBtwqdqcBntTSIHgHOFZGtwLn2fUSkq4gstNcZBUwHxorIavtnUgv3q7xA7ZRBH998Bn+9sD879pdw4T+XcdOcVWQeKHG6PK9WWV3D3O+yGZMST7dYnVpJBbYWnWk2xhwAzq7n8Vxgkn37a0A/pvixILeLaSO6M3XQ0SmDPlmfx9Wndue3Y/volEH1+GzDXvYVlfOwziunlM6soFqP55RBlwxN4uVlu3TKoAbMWpFJYvtwxvQ95kBTpQKCBpFqdQkxYTx8kTVl0Ck9Ynnkw02M/X+LeW/1br0GErAjv5il2w5w5fBuuF3aWaCUBpFqMykJ0bz4s1OY88sRdIwK5XdzV3PHW2sDvnU0Z0UWQS7hslO6OV2KUl5Bg0i1uZG943j3plH8ZuxJvJGRw2XPf8Puw0ecLssRZZXV
vLkyh/H9OtMpOszpcpTyChpE6oRwu4Rbx/Vl5vSh7MwvYcrTX7M0ACdT/WDtHgqOVDJN55VT6n80iNQJNa5fZ96dMYqOkSFM/88Knv9ye0CdN5q1IpNe8ZGc1quj06Uo5TU0iNQJ1zs+induGsWE/p15+MNN3DRnFcXlVU6X1ebW5xbwfdZhpo3orjNQKOVBg0g5Iio0iGevGsJdk1L5aF0eFzy7lO35xU6X1aZmr8giNMjFJUOSnC5FKa+iQaQcIyJcf0ZvZv3fCA6WVDD1maV8vD7P6bLaRFFZJe9+v5spA7vSLkIn4VTKkwaRctzIk+J4/zej6RUfyQ2vruT/fbyZaj+7RPm7q3Mprajmap1JQamf0CBSXiGxfThv3HAalw1L4plF2/j5f7/jcGmF02W1CmMMs5dn0q9rDAOT2jldjlJeR4NIeY2wYDePXjyAhy48mW+272fKM1+zPrfA6bJabGXmITblFXH1qTpIQan6aBApryIiXDUimTduOI3KKsNF/1zGO9/nOF1Wi8xekUV0aBDnD+zqdCnK24lAcBhIYF1BVoNIeaXByR14/zejGdStPb9/fQ33z19PRVWN02U12cGSCj5Yu4cLhyQSqZdVV3XVBk94e4hOgHbdICoh4C5lrn8ZymvFR4cy67oRPPrhJl74eifrcwt49qohdIrxnalx5q3MpqK6hmkjdJCCAsQFQSEQFAZBoeAO1Yv5oS0i5eWC3S7umZzOU1cOZt3uQiY//TUrMw86XVaj1NQYZq/I4pQeHejbOdrpcpQTxAXB4XaLpzO0S7JaPGHtrDDSEAI0iJSPOH9gV965aSThIW4uf345r3yzy+unBvp6234yD5TqkO1A8r/g6WAFT/tuENXJDh5t/TREg0j5jNTOMcyfMZozUuK597313PrmGq++pMTsFZnERoYwoX9np0tRbaXB4Imxgkc1igaR8intwoN54Zph3HxOH95etZuL/7WM7IOlTpf1E3kFZXy2cR+XDksiNCiwTjz7NZcLgiOs4InposHTSjSIlM9xuYSbz0nhP9cOI+tgKVOe+ZolW/KdLutH5n6XRXWNYdpw7ZbzabXBExFrBU+7bhAVbwWPO8Tp6vyGBpHyWWenJfD+jNEkRIdx7Uvf8uyibV5x3qiquoa532ZzRko8yR0jnC5HNYXLBSH1BE9otAZPG9IgUj6tR1wk79w0kskDuvLYx5u5cdZKisoqHa3ps437yCss4+oRevE7r+dyewRPVyt4IjV4TjQNIuXzIkKCeOqKQfxpcjqfbdzH1GeXsm1fkWP1zF6RSZd2YYxN7eRYDaoBPwmeJI/g0VnRnaJBpPyCiPB/o3sy+7oRFB6pZOozS/nwhz0nvI5d+0v4aut+rjglmSC3/nk5zuWGkEgNHi+nfynKr5zaqyPv/2Y0fRKi+dXsVTzy4aYTekmJ177Nwu0Srhje7YTtU3lwB1nBE9nRI3jiNHi8nAaR8jtd2oXz+g2nctWIZJ77cjvXvvgtB0va/pISZZXVvJGRzblpCST40DREfiMiFmISreAJidLg8SEaRMovhQa5eejCk3n04pP5dtdBpjz9NT/ktO0lJT5ct4dDpZU6k4JTdNYCn6VBpPza5ackM+/G0zDGcPFzy3gzI7vN9jV7eRY94yIZ2btjm+1DKX/UoiASkVgR+VREttq/OxxjXbeIfC8iC1qyT6WaakBSe97/zWiGde/A7fPWcs+7P7T6JSU25RWSkXmIq4Yn43LpJ3OlmqKlLaI/Ap8bY/oAn9v3G/I7YGML96dUs3SMCuWVXwznhjN6MWt5FpfP/Ia8grJW2/7s5VmEBLm4ZGhSq21TqUDR0iCaCrxs334ZuKC+lUQkCTgPeKGF+1Oq2YLcLu6clMazVw1hc14Rk5/+mhU7DrR4uyXlVbzz/W4mn9yFDpH6JUilmqqlQZRgjNkDYP9u6Bt8TwJ/AHzvEpvK75w3oAvv3jSKmLAgpr2wgpeW7mzR1EDvrt5N
cXkV03SQglLNctwgEpHPRGRdPT9TG7MDEZkM7DPGrGzk+teLSIaIZOTne9dElsp/pCRE8+6MUZyV2ok/v7+Bm19fzZGKpl9SwhjDrOVZpHWJYUhy+9YvVKkAcNwgMsacY4zpX8/Pe8BeEekCYP/eV88mRgHni8guYC4wVkRmHWN/M40xw4wxw+Lj45t1UEo1RkxYMM9fPZTbxqUwf00uF/5zKZkHSpq0je+zD7NxTyHTRiQjOnxYqWZpadfcfOBa+/a1wHt1VzDG3GmMSTLG9ACuAL4wxlzdwv0q1SpcLmHG2D689LNT2FNQxpSnv2bR5vo+T9Vv1vJMIkPcXDA4sQ2rVMq/tTSIHgHOFZGtwLn2fUSkq4gsbGlxSp0oY/p24v0Zo0nsEMEv/vsdT32+lZrjTA10uLSCBWv3cMHgRKJCg05QpUr5nxYFkTHmgDHmbGNMH/v3QfvxXGPMpHrWX2yMmdySfSrVVpI7RvD2r0ZywaBEnvh0C9e/mkHhMS4pMW9lDhVVNTqTglItpDMrKOUhPMTNE5cN5M/n92Px5nymPrOUzXk/vaSEMYbZK7IY2r0DaV1iHKhUKf+hQaRUHSLCtSN78Nr1p1JcXsWF/1zKgrW5P1pn2fYD7NxfwjS9+J1SLaZBpFQDTukRy4LfjCatSwwz5nzPXz/YQFW19VW4WcszaR8RzKSTuzhcpVK+T4NIqWNIiAnjtV+eyjWndeffX+1k+n++ZUNuIZ9s2MulQ5MIC3Y7XaJSPk+H+ih1HCFBLh6Y2p8BSe25+50fOP+Zr6muMVw1QgcpKNUaNIiUaqRLhiaR2jmaX81eSVrnGHrGRTpdklJ+QYNIqSbon9iOL287i+oWzE2nlPoxDSKlmsjlElzodD5KtRYdrKCUUspRGkRKKaUcpV1zSinlNBEIDofgCAiJAndgvTUH1tEqpZS3cLkhJNL6CY6wwihAaRAppdSJEhR6NHyCQp2uxmtoECmlVFsROdriCYm0WkHqJzSIlFKqNbmDj4ZPcHhAd7k1lgaRUkq1hAgEhUFI7UCDYKcr8jkaREop1VQuFwR7DDRw6TdhWkKDSCmlGiMoxGrxBEdAcJjT1fgVDSKllKqPiD3IIMJq/QTYd3tOJH1llVKqljvI7nKLCPjv9pxIGkRKqcAWHHZ0RoOgEKerCUgaREqpwOJyHf1eT3CEfrfHC2gQKaX8X1CIR5dbuNPVqDo0iJRS/udHk4hG6nd7vJwGkVLKP7jcEBZzNHx0oIHP0CBSSvmHiFinK1DNpF8HVkop5SgNIqWUUo5qURCJSKyIfCoiW+3fHRpYr72IzBORTSKyUUROa8l+lVJK+Y+Wtoj+CHxujOkDfG7fr88/gI+MManAQGBjC/erlFLKT7Q0iKYCL9u3XwYuqLuCiMQAZwD/ATDGVBhjDrdwv0oppfxES4MowRizB8D+3amedXoB+cBLIvK9iLwgIpEt3K9SSik/cdwgEpHPRGRdPT9TG7mPIGAI8C9jzGCghIa78BCR60UkQ0Qy8vPzG7kLpZRSvuq43yMyxpzT0DIR2SsiXYwxe0SkC7CvntVygBxjzAr7/jyOEUTGmJnATIBhw4aZ49WnlFLKt7W0a24+cK19+1rgvborGGPygGwR6Ws/dDawoYX7VUop5SfEmOY3OkSkI/AGkAxkAZcaYw6KSFfgBWPMJHu9QcALQAiwA/i5MeZQI7afD2Q2s7w4YH8zn+tt/OVY/OU4QI/FG/nLcUDLjqW7MSa+NYtpay0KIm8mIhnGmGFO19Ea/OVY/OU4QI/FG/nLcYB/HUtj6MwKSimlHKVBpJRSylH+HEQznS6gFfnLsfjLcYAeizfyl+MA/zqW4/Lbc0RKKaV8gz+3iJRSSvkAvw4iEblfRHaLyGr7Z5LTNbWEiNwmIkZE4pyupblE5EERWWv/e3xiD/X3SSLymD2j/FoReUdE2jtdU3OIyKUisl5EakTEJ0dqicgEEdksIttEpMEvzHs7EXlR
RPaJyDqnazmR/DqIbH83xgyyfxY6XUxziUg34Fys72v5sseMMQOMMYOABcC9DtfTEp8C/Y0xA4AtwJ0O19Nc64CLgCVOF9IcIuIGngUmAunAlSKS7mxVzfZfYILTRZxogRBE/uLvwB8Anz6pZ4wp9LgbiQ8fjzHmE2NMlX13OZDkZD3NZYzZaIzZ7HQdLTAc2GaM2WGMqQDmYl0ZwOcYY5YAB52u40QLhCCaYXedvNjQhfu8nYicD+w2xqxxupbWICJ/FZFsYBq+3SLy9AvgQ6eLCFCJQLbH/Rz7MeUjjjvpqbcTkc+AzvUsuhv4F/Ag1qfuB4HHsd4wvM5xjuMuYNyJraj5jnUsxpj3jDF3A3eLyJ3ADOC+E1pgExzvWOx17gaqgNknsramaMxx+DCp5zGfbWkHIp8PomPNDu5JRP6NdU7CKzV0HCJyMtATWCMiYHX/rBKR4faEsl6nsf8mwBzgA7w4iI53LCJyLTAZONt48XchmvBv4otygG4e95OAXIdqUc3g111z9qUpal2IdVLWpxhjfjDGdDLG9DDG9MD6oxvirSF0PCLSx+Pu+cAmp2ppKRGZANwBnG+MKXW6ngD2HdBHRHqKSAhwBdaVAZSP8OsvtIrIq8AgrGb6LuCG2ivK+ioR2QUMM8b45CzDIvIW0BeowZpZ/UZjzG5nq2oeEdkGhAIH7IeWG2NudLCkZhGRC4GngXjgMLDaGDPe0aKayP5qxpOAG3jRGPNXZytqHhF5DRiDNfv2XuA+Y8x/HC3qBPDrIFJKKeX9/LprTimllPfTIFJKKeUoDSKllFKO0iBSSinlKA0ipZRSjtIgUn5BRDp6zLKe5zHrerGI/LMN9nejiFzTxOcs9tXZrZVqSz4/s4JSAMaYA1jfGUNE7geKjTH/rw3391xbbVupQKMtIuXXRGSMiCywb98vIi/b10HaJSIXicjfROQHEflIRILt9YaKyJcislJEPq4zQwce27rNvr1YRB4VkW9FZIuInG4/Hi4ic+1Jd18Hwj2eP05EvhGRVSLypohEiUh3EdkqInEi4hKRr0TEZ+YYVKq5NIhUoOkNnId1mYBZwCJjzMnAEeA8O4yeBi4xxgwFXgQa8y39IGPMcOBmjs6d9yug1L5e0V+BoQD2hQ3vAc4xxgwBMoBbjDGZwKPAc8CtwAZjzCctP2SlvJt2zalA86ExplJEfsCaDuYj+/EfgB5Y0w/1Bz61J5l1A42ZFupt+/dKezsAZwBPARhj1orIWvvxU7Eu4LbU3kcI8I293gsicilwI3ZXo1L+ToNIBZpyAGNMjYhUesyYXYP19yDAemPMac3ZLlDNj/+u6ptDS4BPjTFX/mSBSARHL7AXBRQ1sQ6lfI52zSn1Y5uBeBE5DUBEgkWkXzO3tQTr4n+ISH9ggP34cmCUiJxkL4sQkRR72aNY1zW6F/h3M/erlE/RIFLKg32p6UuAR0VkDbAaGNnMzf0LiLK75P4AfGvvIx/4GfCavWw5kCoiZwKnAI8aY2YDFSLy8xYcjlI+QWffVkop5ShtESmllHKUBpFSSilHaRAppZRylAaRUkopR2kQKaWUcpQGkVJKKUdpECmllHKUBpFSSilH/X/lXXyA9Ujt3QAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaIAAAEkCAYAAABt4jWqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAABEWElEQVR4nO3deXxU9b3/8ddnJvuekAVIgLAkQAAhgLhBKwKKK0Lrdava2/aq19rN2muv9nbR1p/e3i7XbkKtvVqxtqVQN9QaintFA7KHfQ1LNrKHLJP5/v44JzjEhITMJGcy83k+Hnkkc+bMOd+TwLzn+z3f8zlijEEppZRyisvpBiillApvGkRKKaUcpUGklFLKURpESimlHBXhdAOUUirYrV+/PjMiIuIJYDL6Ad4fXmCrx+P50owZM8o7FmoQKaVUDyIiIp4YOnToxIyMjGqXy6VTjfvI6/VKRUVFwfHjx58ArulYrsmulFI9m5yRkVGnIeQfl8tlMjIyarF6lh8vd6g9Sik1mLg0hALD/j2elj0aREopNUjt3LkzKi8vb5LT7ehs1qxZ499666243q6vQaSUUuqUtra2Ad+nBpFSSg0S3//+97Py8vIm5eXlTXrwwQczATweD0uWLMnNz88vWLhw4Zj6+noXwF133ZU9duzYSfn5+QW33357DsDRo0cjLrvssrGTJ0+eOHny5Il///vf4wHuueee4TfeeOOoiy66KG/JkiWjzznnnAnFxcUxHfudNWvW+Lfffjuurq7Odd111+VOnjx54sSJEwueeeaZFICGhga56qqrxuTn5xdceeWVY5qbm+VsjktnzSml1Fn41opNI3Ydr+/1sFNv5A9NbPrxZ6cePtM6b7/9dtyzzz47ZP369SXGGGbMmDFx3rx59QcOHIhZunTpgUsvvbTxuuuuy/3xj3+c8eUvf7ly9erVqfv27dvqcrmorKx0A9xxxx0j7rnnnrLLLrusYffu3VGXXXZZ3r59+7YBbN68OW7dunU7EhISzA9+8IPM5cuXp82cOfPowYMHI8vLyyPnzJnTdPfdd2fPnTu37i9/+cuByspK98yZMydec801dT/96U8zYmNjvbt27dq+bt262IsuuqjgbI5fe0RKKTUIvPHGGwlXXHFFTVJSkjc5Odl75ZVXVq9duzZx6NChrZdeemkjwC233FL13nvvJaSlpbVHR0d7b7jhhlFPPfVUSkJCghfg3XffTfra1742csKECQVXX331uIaGBnd1dbULYOHChTUJCQkG4NZbb61+4YUXUgGefvrp1KuvvrrabkPSz372s2ETJkwomD179viWlhbZs2dP1DvvvJNwyy23VAGcd955J/Pz85vO5ti0R6SUUmehp55Lf+nuTgki8onHkZGRbNy4seSFF15Ieu6551J/85vfZL7//vu7jDEUFxeXdASOr/j4eG/Hz6NHj25LSUnxrFu3LnblypVpS5cuPdjRhhUrVuyZOnVqS0/tOBvaI1JKqUHgkksuaVi9enVKfX29q66uzrV69erUuXPn1h87diyqqKgoHuDZZ59Nu/DCCxtqa2tdJ06ccF9//fW1jz/++OGSkpI4gNmzZ9c9+uijmR3bfO+992K7299nP/vZEw8//PDQ+vp696xZs04CzJ07t+4nP/lJltdrZda7774ba2+34ZlnnkkD+PDDD2N27dp1VkOXGkRKKTUIzJ49u+mmm26qmj59+sQZM2ZMvOWWWyrS09Pbx4wZ0/zkk08Oyc/PL6iuro649957K2pqatwLFy7My8/PL5gzZ874H/7wh4cBli1bdnjDhg3x+fn5BWPHjp30y1/+MqO7/X3uc5+rfvnll9MWLVp0omPZI488ctTj8ciECRMK8vLyJn3nO9/JBrj33nvLGxsb3fn5+QUPP/zw0ClTpjSezbGJ3hhPKaXObNOmTQemTp1a6XQ7QsWmTZvSp06dmtvxWHtESimlHKVBpJRSylEa
REoppRylQaSUUspRGkRhSEQOiMh8B/YrIvKoiFTZX/8t3Vx8ICK5ImJEpMHn6798np8rImtFpFZEDnSzja+JyH4RaRSREhHJt5ff32m7J0XEKyLpPq+dLyIb7NceFpF/sZeni8i7dvtrROSfInKRz+se77TtFhGp76JteSLSLCLPdFoeJyK/FpFK+9je8nnulU7bbhWRLT7PrxWRChGpE5FNIrKo07a/Yv8+6kSkWERmd9GuNHsb7/gs6+mYbxCRnXZ7y0XkKRFJ6vS3XC0i1SJyXER+KSJ6DaM6RYNIDaTbgWuBqcA5wFXAHT28JsUYk2B/PeSzvBF4EvhWVy8SkS8BXwSuBBLsfVUCGGMe9tlmAvAo8IYxptJ+bQHwLPAAkAxMA9bbm24AvgBkAKn2a1/seGM1xtzZadt/BP7SRRN/BXzYxfJlQBow0f7+jY4njDGXd9r2e522/TVgmDEmCet3/YyIDLOP6TzgEeCz9jH9DlglIu5O+38UKOm07IzHDLwLXGSMSQbGYF0o/0Of1/8aKAeGYf0uPw3c1cWxqzClQaROEZFoEfm5iBy1v34uItH2c+ki8pL9ifiEiLwtIi77uftE5IiI1NufjOd1s4vbgJ8YY0qNMUeAnwCf70tbjTEfGGP+AOzr4jhcwPeAbxhjthvLXmPMiS7WFeAW4Cmfxd8BlhpjXjHGeIwxVcaYvfZ+m40xO40xXkCAdqw357Quth0PfKbTthGRG4AaYE2n5eOx7lp5uzGmwhjTboxZTxdEJBeYA/zB53ey2Rjj6XgIRAIj7Me5wDZjzHpjXbPxNJAOZPps8wKsG5b93ndfPR2zMeZwR4jb2oFxPo9HA3+2t3MceBUIulsXhJOXXnopce7cueMAli9fnnz//fcP7W7dyspK9yOPPNLt9Ubdueeee4Z/97vfzerNuhpEytcDwPlYn1qnArOw3pQBvgmUYn0qzgLuB4z95nk3cK4xJhG4DDjQzfYnAZt8Hm+i5zekgyJSKiK/9x0660GO/TXZHlbbLyI/6AjOTubYx/NXn2XnA4jIFhE5JiLPiMhpQSMim4Fm4AXgCWNMeRfb/gxQAfgOryUBD2L9Pjs7DzgI/MAemtsiIp/p5hhvBd42xuzv1K6XRKQZWAe8ARTbT70CuEXkPLsX9AVgI3Dcfp0bq5d2N1aIfcKZjllEZotILVBvH/fPfV76v8AN9rBjNnA5VhipAPN4PD2v1MnNN99c+/DDDx/v7vmqqir37373u8zung8EDSLl62bgQWNMuTGmAvgBVm8BoA1raGWUMabNGPO2/cm6HYgGCkQk0hhzoKP30IUEoNbncS2QYPdKOqsEzgVGATOARGB5L48jx/5+KTAFmAvciDVU19ltwApjTEOn19+C9YaaB8QCv/B9kTHmHCAJuAl4h67dBjxtTr9q/CHgd8aYruqV5WD1SGqB4Vih8JSITOxi3VuB/+u80BhzFdbv6grgNbsXA1ZA/NVuawtWj/F2n7Z9FVjXXQ/M3na3x2yMeccemssBfszpH0bexPrAUYf1YaYY+Ft3+1Fd27lzZ9To0aMndb7lQ3Z29pR777132IwZM8Y/+eSTqStXrkyaNm3ahIKCgomXX375mNraWhfAihUrkkaPHj1pxowZ41esWJHSsd3HHntsyK233joS4PDhwxELFiwYO378+ILx48cXvP766/Hf/OY3cw4fPhw9YcKEgjvuuCMH4L/+67+yJk+ePDE/P7/gG9/4xvCObd13331Dc3NzJ1944YX5u3fvju7tsekJQ+VrONYn8g4H7WVgvbl8H/i7nRvLjDGPGGP2iMjX7ecmichrwD3GmKNdbL8B642sQxLQ0OmNGgA7GDo+zZeJyN3AMRFJMsbU9XAcJ+3v/22MqQFqRGQp1pvzbztWEpFY4DpgURev/70xZpe93sNAURdtbAb+KNZEiI3GmFO9PREZgXUu5N98lk0D5gOFZ2h3G/BDe4jtTRFZixWop87b2JMM
hgIrutqIMaYNeEWsyRp7jTEvAF/C6gVNAvbY23xJRDra8lWswD+jMx2z/fwREXkVeA6YbvdCXwOWAhdifRh5Eus803/0tL+g9Lcvj6B8e0BvA0FmQRPX/qrHYqpd3fIBICYmxrt+/fqdx44di7j66qvHvvXWW7uSkpK8DzzwwNCHHnoo68EHHzx+9913577++us7J02a1HLVVVeN6Wr7d95558g5c+bUf/e7393r8Xiora11/+QnPym96qqrYnfs2LEdYOXKlUl79uyJ2bx5c4kxhvnz54975ZVXEhISEryrVq1K27Jly/a2tjamTZtWUFhY2Ksq3BpEytdRrB7INvvxSHsZxph6rOGkb4rIJGCtiHxojFljjHkWeNYedlqK9SZzyye2bm13KvCB/Xiqz7560hFWvSnxuxNopZshJh9LgBNYQ1i+Nvfitb4isU7S+74p3wq8Z4zxPYd1Mda5mkN2mCdgDZcVGGOm2/vtjduAlZ16cV2JAMbaP08FXuwIV+BVETmGFQ4erN7udrtdsUCsiBwHso0x7V1su6tj7mq/aVjnqX5pjGkBWkTk91iTGQZnEDmo8y0fHnvssUywbtsA8MYbb8Tv3bs3ZtasWRMA2traZMaMGQ0bN26MycnJaZkyZUoLwM0331z1xBNPfOK8z3vvvZe4YsWK/QAREREMGTKkveNeRh1effXVpLfeeiupoKCgAKCpqcm1Y8eOmPr6etcVV1xRk5iY6AW49NJLa3p7XBpE4StSRGJ8HnuwZnh9R0Q+xHoj/i7wDICIXAXsAPZiDbG0A+32OaJsrJlTzVif6rsb8n0auEdEVtvb/yadhrw62LO8aoDdWCfGH8Oa2VZrP+8CorDeEMU+Fq8xptUY0yQifwL+Q0Q+wpol9m9YvTpfXQ2dgXWy/r/Emlp9HLgPeMne7/lY/28+ANxYPYksrHMyvm7FCmRfy7B6Ch3uxQqmf7cfvwUcAv5TRP4f1jmji/GZGejTi1vS6fc1AWtSwBtYf8vrgU/x8Zv9h8ADIvILYD9Wzywf2Go/zvXZ3PVYw2+LjDHtPR2ziNwMvA0cxvrw8iPsiRjGmEoR2Q/8u4j8D1b43kbXATY49KLn0l86j2J3PO548zfGMHv27LoXX3zxtHOH7733XmzXI+BnzxjD17/+9WPf+ta3Tqu99+CDD2b2dR96jih8rcYKjY6v72N9Si3G+mS+BdjAx9Nw87CGpxqAfwK/Nsa8gXV+6BGsczrHsWZh3d/NPpcCL9rb3gq8bC8DQES22W9qYH3afhXr3MZWrPMaN/ps61N2u1djvfmdBP7u8/zddluP2u19FmtIqGNf2cAlWOF4GmPMk/bydVjDky1Yb77Yx/sroAo4gjXcd6XvUKQ9+yyHTtO2jTFNxpjjHV92+5rt83EdQ2qL7G3WYg0j3mqM2eGzmWvt59Z2arZg/Q3LsSZIfA243hizwX7+aawQfAPrg8RjwB3GmB3GmJZO7aoF2uyfe3PMBVhTyRuwPpDsxGdIEis0F9rt2oMVlN9AnbWubvng+/zFF1/cWFxcnLB169ZogPr6etfmzZujp02b1lxaWhq1bdu2aIDnnnvuE7M8AS666KL6juE+j8fDiRMnXMnJye2NjY2nsuLyyy+v+8Mf/pDece5p//79kUeOHIm45JJLGl5++eWUhoYGqa6udr3++uspvT0urb6tlFI9CIbq2zt37oy64oor8s4777z64uLihNGjR7esWLFi/4QJEyYVFxeXDBs2zAPwwgsvJN5///05ra2tAvC9733vyM0331y7YsWKpG9961sj0tLSPOedd15DSUlJ7Nq1a/c89thjQ4qLi+OffvrpQ4cPH474/Oc/P+rw4cPRLpeLX/7ylwfnz5/fePXVV4/esWNH3CWXXFK7dOnS0oceeijzD3/4QzpAXFycd/ny5fsnTZrUct999w3905/+lJ6dnd0yfPjwtokTJ5588MEHyzofS+fq2xpESinVg2AJoquu
uipv9+7dvT2vGrT0NhBKKaWCigaRUkoNAuPHj28Nhd5QVzSIlFJKOSqop2+np6eb3Nxcp5uhlApzjz76KNu2bRsVqCnQ/amlpcVTWFgYtNPjvV6vAF7fZUEdRLm5uRQXF/e8olJK9aP9+/eTmJjIkCFDPnEtT7DZunVrq9Nt6I7X65WKiopkrEsyTgnqIFJKqWCQk5NDaWkpFRUVTjelR8ePH49ob2/vbYHggeYFtno8ni/5LtQgUkqpHkRGRjJ69Ginm9ErBQUFW4wxM51ux9nQyQpKKaUcpUGklFLKURpESimlHKVBpJRSylEaREoppRylQaSUUspROn1bKaUGkjHgbQevB0y79bNpB6/342XxmeAOn7fn8DlSpZTqD70JllPL2q31exLn7XmdEKJBpJRSvjqCxdjh4hssn1jWy2BRZ6RBpJQKbacFi2/PRYMlWGgQKaVCQ0sDtDV1MUymwRLsNIiUUqGh7SQ01zndCtUHOn1bKaWUozSIlFJKOUqDSCmllKM0iJRSSjlKg0gppZSjNIiUUko5SoNIKaWUozSI1IDYV9HAs+sOOd0MpVQQ0iBSA+IX/9jD/au2sLus3ummKKWCjAaR6needi9rd5YDsPKjIw63RoWidq/h3f21TjdD9VFAgkhEForIThHZIyLfPsN654pIu4h8NhD7VYPD+oPV1DS1kRQTwfMfHcHr1dpfKrD+uqGUm5fvpPhIk9NNUX3gdxCJiBv4FXA5UADcKCIF3az3KPCav/tUg0tRSRlRbhffvnwiR2ubWbf/hNNNUiHmqnOGkRobweMfVDvdFNUHgegRzQL2GGP2GWNageeARV2s9xXgr0B5APapBpE1JeWcP3YIiwuzSYiOYNVHpU43SYWYuKgIbpmZSdHeBvZUtTjdHHWWAhFE2cBhn8el9rJTRCQbWAw83tPGROR2ESkWkeKKiooANE85aW9FA/sqG5k/MZPYKDcLJw/llS3HaW5rd7ppKsTcNjOL6Ajht8XaKxpsAhFE0sWyzicBfg7cZ4zp8d3HGLPMGDPTGDMzIyMjAM1TTlpTUgbAvIlZACwpzKa+xcPr28ucbJYKQUPiI7luUjKrttdR3uBxujnqLAQiiEqBET6Pc4CjndaZCTwnIgeAzwK/FpFrA7BvFeSKtpczcVgS2SmxAJw/ZgjDkmNYpbPnVD/40sxUPF7D7zdor2gwCUQQfQjkichoEYkCbgBe8F3BGDPaGJNrjMkFVgB3GWP+FoB9qyBW3dhK8cETLJiYeWqZyyUsmpbNm7sqqGzQsXwVWLmpUSzMS+CZTTU0tHqdbo7qJb+DyBjjAe7Gmg1XAvzZGLNNRO4UkTv93b4avNbuLMdrYH5B1mnLl0zPpt1reHFT546zUv6749w06lu8PLe5xummqF4KyHVExpjVxph8Y8xYY8yP7GWPG2M+MTnBGPN5Y8yKQOxXBbeikjIyE6OZPDz5tOX5WYlMGp6kw3OqX0wdFst5ObH8bn01be16zdpgoJUVVL9o8bTz1q5K5k3MwuX65HyWxYXZbC6tZU95gwOtU6HuzllpHKv38OKOOqebonpBg0j1i3X7TtDQ4mG+z/khX9dMHY5L0GuKVL+4eHQ8+UOiWPbhCYzRXlGw0yBS/aKopIyYSBcXjUvv8vnMpBhm52Xwt4+OaskfFXAiwu3nprGjspU3D2jZn2CnQaQCzhjDmpJyZo/LICbS3e16SwqzOVJzkg8OaMkfFXjXTExiaEIEyz7Uf1/BToNIBVzJsXqO1JxkQUHXw3IdLp2URVyUm1UbdNKCCrwot/CFGam8d6iJLcebnW6OOgMNIhVwa0rKEIFLJmSdcb24qAgWTh7K6i3HtOSP6hc3npNMYpSLpdorCmoaRCrgikrKmJqTQkZidI/rLinMob7Fw5oSrYWrAi8x2s1NU1NYvaueQzWtTjdHdUODSAVUeV0zm0prWVBw5t5QhwvGDiErKVpnz6l+84UZqbgFfrdey/4EKw0i
FVBrdlg9m3ndTNvuzO0Srp2WzRs7K6jSkj+qH2QlRHBtQRJ/2lLLiSYthhqMNIhUQBVtLyMnNZbxWYm9fs3i6dl4vIaXNh/rx5apcHb7zDSaPYY/bKxxuimn87aDpwVaG6G5FpqqoKEM2tucbtmA0iBSAXOytZ139lQyf2IWIl3dHaRrE4YmMXFYEiu15I/qJ3np0cwbE89TH9Vwsm0Ai6EaY4VK20loqYeT1dBQAXXHoOYw1JZC/XForISTNdDSAG3NfPJOOqFNg0gFzDt7KmnxeJk/sXfnh3wtKcxm0+Ea9lZoyR/VP24/N40TJ9tZsa02sBv2esDTDK0NVpg0VlrhUlsKNYeg7ig0lEPTCWiug7YmaG8Fo9XBO2gQqYBZU1JGYnQEs0annfVrr5lmlfz5m/aKVD+ZlRPLtGExPFFcTfvZVPMwXis4WpusIGk6YQVL3VEraGqPQH0ZNFZZw2utjdZwm1cvSegtDSIVEF6voaiknE+NzyAq4uz/WWUlxXDRuHRWfXRES/6ofiEi3HluGgdr2nhtt0/Pu6vhs8YKqD8GtYetIbS6Y9ayk9XWOm0nrddoHbuA0CBSAbGptIbKhhYW9GFYrsPiwmxKq09SfFCn2aoAsycFLBgp5Ca7WbquHFN/HOqOdD181toEnlbw6vDZQNAgUgGxpqQct0u4eHxGn7dx2aShxEa69Zoi1TdeT4+TAtwnq/i3ycKmcg/rDp+Edp3OHQw0iFRAFJWUMXNUKilxUX3eRny0VfLnpc1a8kf1wcnaXk0K+Mw4N0NiYOlmDaFgoUGk/Hb4RBM7jtf3uprCmSwuzKa+2cPaHVryR/WPmAjh8wURrC31svOEDr0FAw0i5bc1JWUAzPPj/FCHi8alk5kYrdcUqX71uYluYiNg2VbtFQUDDSLlt6KScsZmxDM6Pd7vbbldwqJpw3ljZznVjVqkUvWP1Bjh+nw3L+z1cqxRZ745TYNI+aWuuY11+6uYH4BhuQ6LC3Noaze8tPlowLapVGdfnOzGa+D327RX5DQNIuWXt3ZV0NZu+lRNoTsFw5OYMDRRh+dUvxqR6OLK0S6e3dFOXav2ipykQaT8sqaknNS4SKaPTA3odhcXZvPRoRr2VzYGdLtK+bp9SgQNbfDsDp2l6SQNItVnnnYv/9hRztwJmbhdvS9y2huLpmUjAqu0V6T60eR0F7OHu3hym4eWdu0VOUWDSPVZ8cFqak+2+VVNoTtDk2O4aGw6f/voCEbLqKh+dPsUN+VN8Pxe7RU5RYNI9dmakjKi3C7m5Pe9msKZXFuYzaETTazXkj+qH83JdjExTVi2pR2vfuhxhAaR6rOiknLOHzuEhOiIftn+wslDiYl06aQF1a9EhDumRLCnxrD2sF7g6oSABJGILBSRnSKyR0S+3cXzN4vIZvvrPRGZGoj9KufsrWhgf2UjC3p5S/C+SIiO4LJJQ3l58zFaPDpsovrPlWNcZMdr2R+n+B1EIuIGfgVcDhQAN4pIQafV9gOfNsacAzwELPN3v8pZRdutagqX9MP5IV+LC7OpPdmmJX9Uv4p0CV+YHMEHZYYN5dorGmiB6BHNAvYYY/YZY1qB54BFvisYY94zxnQM9L8P5ARgv8pBRSVlFAxLIjsltl/3M3tcOukJ0azcoMNzqn/dMN5NchQs26K9ooEWiCDKBg77PC61l3Xni8Ar3T0pIreLSLGIFFdUVASgeSrQTjS2sv5gdUCrKXQnwu1i0bThrN1ZTk2TlvxR/Sc+UvjcRDevHfCyv1Z7RQMpEEHU1QUkXU49EZG5WEF0X3cbM8YsM8bMNMbMzMjon9lYyj9rd5TjNTC/H88P+VpcmG2X/Dk2IPtT4eu2gggi3fDbrXpOciAFIohKgRE+j3OATxQJE5FzgCeARcaYqgDsVzlkzY4yspKimTw8eUD2N2l4EvlZCXpxq+p3mXHCZ8a5WbG7nYqTOpV7oAQiiD4E8kRktIhEATcAL/iuICIjgZXALcaY
XQHYp3JIi6edN3dWcMmELFwBrqbQHRFhcWEO6w9Wc7BKS/6o/vWlyW7a2uHp7XquaKD4HUTGGA9wN/AaUAL82RizTUTuFJE77dW+CwwBfi0iG0Wk2N/9Kmes23eCxtZ2FhQMzLBch2sLh2vJHzUgxqa4WDDKxdPb22ls017RQAjIdUTGmNXGmHxjzFhjzI/sZY8bYx63f/6SMSbVGDPN/poZiP2qgVdUUkZMpIsLx6YP6H6HJcdywZghrNKSP2oA3DElgtpW+PMuB84VGWN9hRGtrKB6zRhD0fYy5uRlEBPpHvD9Ly7M5mBVExsO1Qz4vlV4mZHlYmaW8MRWDx5vP4WCMdDeBm1N0FwHTVXQUAb1x8Db1j/7DFIaRKrXSo7Vc7S2uV+KnPbGwslDiY5wseqjUkf2r8LLHVMiONIAL+/3cyr3qcA5aQfOCWgotwKnsQJO1kBrA3hawBues/U0iFSvFZWUIQJzJwzs+aEOiTGRXDppKC9tPkarR6/zUP1r3kgXY5OFZVs8vRsO9g2clno42Tlwqu3AaQavToTwpUGkem1NSRnTRqSQkRjtWBuWFGZT09TG2p1a8kf1L5cIt09xs63K8O7RTh98egqclnpo08DpLQ0i1Stldc1sKq0N6C3B+2JOXjrpCVGs0pI/agBcO8aQEQtLN7VYAdNQDnVHNXACTINI9cqaEqsH4nQQRbhdXD11OP/YUU5tU3id0FX9yNtuDZm1NFgB01gB9ceIPlnBv45r4u1jLrZVtGrg9BMNItUra0rKGJEWS35WgtNNYUlhDq3tXl7a8okCHkqdmbfdmhTQ0mBNErADh4YyaxJBS5015NbedmoK9c1jWoiPMPx2Z4yzbQ9hGkSqRydb23lnTyXzJmQhMjDVFM5kcnYS4zIT+Jte3Kq60xE4rY124FRC/XE7cKrswGk6LXC6kxxluHFMMy8ejqK0Ud8y+4P+VlWP3tlTSYvHy4IBqLbdG1bJn2w+PFDN4RNNTjdHBYvWRmjqFDjNtXbgtILp+0zLL+S1IMCTu7VX1B80iFSPiraXkRgdwbm5aU435ZRrC607jWjJH3WKpwU8/gVOd4bHeblmZCvP7YumttX5UYFQo0GkzsjrNazZUc6nx2cQFRE8/1yyU2I5f0yalvxRA+bf8ptpahee2evc5QuhKnjeWVRQ2lRaQ2VDS9AMy/laUpjD/spGNh6ucbopKgxMTGnn00Nb+f3uGJrDswBCv9EgUmdUVFKG2yVcnO9MNYUzWTilo+SPDs+pgXHH+GYqW1ysOqi9okDSIFJntKaknHNzU0mOi3S6KZ+QFBPJ/IIsXtx0VEv+qAFxQYaHKakefrszhv6qhRqONIhUtw6faGLH8XrHL2I9kyWF2VQ3tfHmrgqnm6LCgAjcPr6ZfQ1uXj8afB/OBisNItWtopIyAOYFcRB9Kj+DIfFRWpFbDZjLs1sZEd/OUr3ANWA0iFS31pSUMy4zgdHp8U43pVuRdsmfopJyak9qyR/V/yJc8KX8ZjZURVJcGeF0c0KCBpHqUl1zG+/vq2LexOCbpNDZ4sJsWj1eVm855nRTVJi4LreF1Cgvj2uvKCA0iFSX3tpVgcdrHLsJ3tk4JyeZMRnxWpFbDZi4CLhlXAtFR6PYU6dvo/7S36DqUtH2MtLioygcmep0U3okIiwpzOaDAye05I8aMLeNaybaZfjtLu0V+UuDSH2Cp93L2p0VzB2fids1OMqZLJpmlfx5fqP2itTAGBJtuG50C6sORlN+cnD8PwlWGkTqE4oPVlN7so0FBcF/fqjDiLQ4Zo1OY6WW/FED6Ev5zXi88Ps92ivyhwaR+oSi7WVEuV3MyctwuilnZUlhNvsqGtlcWut0U1SYyE3wsjCnjWf2RtOgkzb7TINIncYYQ1FJGReMHUJ89OCamnr5lGFEackfNcDuGH+S+jYXz+3XXlFfaRCp0+ytaORAVRPzB8G07c6SYyNZMNEq+dPWriV/1MCYmtbOeRlt
/G5XNG36z65PNIjUadYMgmoKZ3JtYTZVja28pSV/1AC6c3wzx066efFwlNNNGZQ0iNRpikrKKBiWxPCUWKeb0iefzs8gNS6SlTo8pwbQxUPbyE/ysGxnTE93HlddCEgQichCEdkpIntE5NtdPC8i8pj9/GYRmR6I/arAOtHYyvqD1cwPwnsP9VZUhFXy5/XtZdQ169ljNTA6iqHuqI3gzTIthnq2/A4iEXEDvwIuBwqAG0WkoNNqlwN59tftwG/83a8KvLU7yvEaBkU1hTPpKPnzipb8UQPompGtDI31ssyfsj/GQP1x2Pdm4Bo2CARiWtQsYI8xZh+AiDwHLAK2+6yzCHjaWBd4vC8iKSIyzBij7xRBpKikjKykaCZnJzndFL9MG5HC6PR4Vm44wvXnjnS6OSpMRLngC3nNPLw5ji3Vbqak9uI2rs21cGIf1ByA6kNQdxg8zRCdBJOutbpaYSAQQZQNHPZ5XAqc14t1sgENoiDR4mnnrV0VLCrMRgb5P34RYXFhNj99fRel1U3kpMY53SQVJm4c08wvtsewdGcMvzy/8fQn205C9UGo2Q81h6DmMLTUfPx8fCZkTYKUUVB464C222mBCKKu3rU6n67rzTrWiiK3Yw3fMXKkfpodKO/vO0Fja/ugnLbdlY4gen7jUb48d5zTzVFhIjESbhrbwpM7Izg28hDDmvdCzUGoPQwNZR+vGJ0MySMg9UJIGQOpoyDSZ4LQ8Klh0xuCwARRKTDC53EOcLQP6wBgjFkGLAOYOXOmzj8ZIEXby4iNdHPh2HSnmxIQI9LiODc3lZUbSrnr4rGDvpengpgx0FgB1fuhej/frD7M16OPErveniwTEQNJOTBmEqTmQtoYiElxssVBJxBB9CGQJyKjgSPADcBNndZ5AbjbPn90HlCr54eChzGGNSVlzMlLJybS7XRzAmZxYQ73r9rC1iN1TMlJdro5KlS01MGJ/fZ5nYNQVwptdtV3cROVNIz34y7gxfp8vvWpoaSkZYVV76Yv/A4iY4xHRO4GXgPcwJPGmG0icqf9/OPAauAKYA/QBPyrv/tVgbP9WB1Ha5v5+vx8p5sSUFdOGcb3X9jGyo9KNYhU33haoPqAFTo1B63zOs3VHz8flw4Z4yElF1JHW8Nt7kiG1LlY/loKmWVNfG1Is0ONHzwCUkzMGLMaK2x8lz3u87MBvhyIfanAW1NSjgjMnRAa54c6JMdFMm9iJi9uOsoDV0wkwq3Xb6szMF6oLbWG2GoOWKHTUMap09lRiVbQjDzfDp5ciIrvclN5SV7mDWvlqd0x3J7fTOzgKts44PTXoygqKWPaiBQyEqOdbkrALS7M5pWtx3l7d2XIBa3ygzHQWAnV++yezkGoOwpe+7yOOxqSsmH0p62eTmouxA05q13cPr6Z699IYsXBaG4Z2xL4YwghGkRhrqyumc2ltXzrsvFON6VfXDw+kxS75I8GURhrafj4ep0a+7xOqz29WtyQOBSyZ1g9nbQx1mPxrwc9K93DtDQPT+yM4aYxLbj1NFG3NIjC3JqScgAWDOKyPmcSFeHiqnOG8ZfiUuqb20iM0fIrIc/TCrUH7QkFh6D2EJw88fHzcekwJM+6XidtDCSNgIjAFysVsW4R8e//TOS1I5FckaMlp7qjQRTmikrKGJEWS15mgtNN6TeLC3N45v1DvLL1OP8yc0TPL1CDz5H18M5PoXwr1JcB9v0YohKs8zo550LKaEgb3e15nf5waXYbuQntLN0Ry+XZbTp5rhsaRGGsqdXDu3squem8kSF9nc30kSnkDolj1YYjGkShqv4YHHwLEofD6DlW6KSOhvizO68TaG6xbif+nQ3xrKuM4PwMj6PtCVY6jSiMvbO7khaPl/mDvMhpT0SEawuzeX9/FUdrTjrdHNUf8i+Hm1bC7G/C5OsgZ6bjIdThs7ktDIn2snSH3sG1OxpEYayopIzEmAhmjU5zuin9bnFhNsbA3zbqfYpCksvt9+SC/hLj
htvGNbP2eBQ7a0PngvFACs6/nOp3Xq/hHzvKuXh8JpFhcH3NqCHxzBiVyqoNRzB65zI1wG4Z20Ks2/h3i4gQFvrvQKpLG0trqGxoDZkip72xuDCb3eUNbDta53RTVJhJjTZcP7qFFw5FcawpdM/H9pUGUZhaU1KG2yVcnB8+QXTVOcOIcrtYpbcRVw74Yn4zXuD3u7VX1JkGUZgq2l7OubmpJMeFz3U1KXFRzJ2QwfMbj+Jp9zrdHBVmRsR7uSKnlWf3xVDXpr0iXxpEYejwiSZ2ltWH/Gy5riwuzKGyoYV39lQ63RQVhu4Y30yDR3h2b+iV0/KHBlEYKiqxbtAVqtUUzmTuhAySYyN1eE45YnJqOxdltvH73TG09OJO4uFCgygMFZWUMS4zgVFDBu4K82ARHeHmynOG8dq24zS06MWFauDdMf4kZc0unj8U+LJCg5UGUZipa25j3b4TYTks12FJYTbNbV5e3Xrc6aaoMDQny8PEZA+/3RWDV68kADSIws6bOyvweE1YTdvubMaoVEamxbHqo1Knm6LCkFUMtZnddRGsPRY+k4XOJCSD6ERjK206K6pLa0rKSIuPonBkqtNNcUxHyZ/39lZxrFZL/qiBd+WIVrLj2lmqF7gCIRhENU2tXPG/b/Pfr+5wuilBp63dyz92lHPJhEzcrvCePtpR8uf5jUedbooKQ5Eu+EJeMx9URvJRlZb9CbkgSomLYkFBFr99ez+vbdNzAL6KD1RT1+wJ62G5DqPT4ykcmaIlf5RjbhjTQnKkl2U7Y51uiuNCLogAvnPVRKZkJ3PvXzZxqKrJ6eYEjTUlZUS5XczJy3C6KUFhSWE2O8vq2X5MS/6ogRcfAZ8b28KrRyLZXx+Sb8W9FpJHHx3h5tc3TwfgrmfX09ymE/aNMRSVlHHB2CHER+ttqACuOmc4kW5h1Qa9pkg547a8ZiJd8MSu8D5XFJJBBDAiLY6fXDeVrUfq+OHL251ujuP2VjRyoKqJ+WF4EWt3UuOjuHh8Js9vOkq7zqNVDsiMMXxmVAt/ORBNZXP4nrcN2SACuHTSUG7/1Bieef8Qz4f5fWg6qinMm6Dnh3wtKcymor6Fd7Xkj3LIl8Y30+aFp/eEb68opIMI4FuXjWfmqFT+c+UW9pQ3ON0cxxRtL2PS8CSGp+iJUV9zJ2SSFBOhJX+UY8YmelkwvI2n9sbQKPEQlwau8Lq+KOSDKNLt4hc3FRIT6eau5es52Rp+54uqGlrYcKg6rKspdCcm0ir58+rW4zRqyR81kFwREBUPsancURhLbavw50PxEBFjXfUaRkI+iACGJcfy8+unsbu8ge/8bWvYTdddu7MCr0GDqBuLC3M42dau0/1V/3JFQFQcxKZCQhYkZEJMMkTGMmNoJDOzhCe2evCE4fnKsAgigE/lZ/CVS/L464ZS/lIcXqVd1pSUkZUUzeTsJKebEpRmjkolJzVWh+dUYHUZPCkQGQuuT17EeseUCI40wMv7w68qjF9BJCJpIvK6iOy2v3+iboyIjBCRtSJSIiLbRORr/uzTH1+bl8dF44bwX89vpSRMrh1pbmvnzV0VzJuYhYRZd7+3XC5hcWE27+6ppKyu2enmqMHKFQGRvQ+ezuaNdDE2WVi2xRN2ozb+9oi+DawxxuQBa+zHnXmAbxpjJgLnA18WkQI/99snbpfw8+sLSY6N5K7lG6hvbnOiGQPq/X1VNLW2s0CH5c5ocWE2XkPYz65UZ8HltoMn5ePgiU3pdfB8YnMi3D7FzbYqw7v7agLd2qDmbxAtAp6yf34KuLbzCsaYY8aYDfbP9UAJkO3nfvssIzGaX9xYyMGqRr791y0h/8ljTUk5sZFuLhg7xOmmBLUxGQlMHZHCSr24VXXH5bZC5lTwZNnBE9en4OnKtePcZMTC0ncPB2R7g4W/QZRljDkGVuAAZ7xIRURygUJgnZ/79ct5Y4Zw72XjeXnLMZ7+50Enm9KvjDGsKSljTl46MZFaWLEnSwqz
2XG8PmyGbVUPxOUTPJl28KQGNHg6i3YL/zrJqnwSThVhegwiESkSka1dfC06mx2JSALwV+Drxphu/6eLyO0iUiwixRUVFWezi7Ny56fGcsmETH748nY2Ha7pt/04afuxOo7WNms1hV66eupwIlyikxbClbggMsaayZaQCYlDfYJn4Mpi3XmOmz/cek5YfXjsMYiMMfONMZO7+HoeKBORYQD29/KutiEikVghtNwYs7KH/S0zxsw0xszMyOi/4pwul/CT66aSmRjDXcs3UNPU2m/7ckrR9nJE4BKtptArafFRXDw+g+c3HtGSP+HAN3jiM+zgSbOu7RnA4OnMFYaTivwdmnsBuM3++Tbg+c4riDVV63dAiTHmp37uL6BS46P45U2FlNc3880/b8IbYm8+RSVlFI5IIT0h2ummDBqLC3Moq2vhvb1a8ifknCl43OFVySDY+BtEjwALRGQ3sMB+jIgMF5HV9joXAbcAl4jIRvvrCj/3GzCFI1O5/4qJrNlRzrK39zndnIA5XtvMliO1Oix3luZNzCRRS/6EBnFZVQo0eIKeX/1PY0wVMK+L5UeBK+yf3wGCuq/5+Qtz+fDACX782k6mj0xl1ug0p5vktzU7rCKnWk3h7MREurlyyjBe2HSUH17rIS5Kb5kxaLjcVvC4oyAiWsNmEAmbygpnIiI88plzGJEay1f+uIHKhhanm+S3NSXljEyLIy8zwemmDDrXFmbT1NrO37eVOd0UdTZikq2CodEJGkKDjAaRLSkmkl/dPJ3qpja+/tzGQX2yuqnVwzt7Kpk3MVOrKfTBrNw0slNiWanDc0oNCA0iH5OGJ/PgNZN4Z08lv/jHbqeb02dv766k1ePVagp95HIJ1xYO553dFZRryR+l+p0GUSfXnzuCJYXZ/O+a3by9u/+uY+pPa0rKSIyJ4NwQONfllMWFOXgNvLDpqNNNUSrkaRB1IiL8cPFkxmUk8PXnNnK8dnB9IvZ6Df/YUc7F4zOJdOuft6/GZSZwTk6ylvxRagDoO1UX4qIi+M3npnOyrZ2v/HEDbe2Dpyz7xtIaKhtamT9RL2L11+LCbLYfq2Pn8Xqnm6JUSNMg6sa4zET+35IpfHigmv/5+06nm9NrRdvLcLuEi/M1iPx19dThuF3Cyo/C6/5VSg00DaIzWDQtm5vOG8nSN/dRtH1wTOVdU1LOrNw0kuN0+qq/0hOi+XR+Bs9/dHRQz6JUKthpEPXgu1cVMGl4Et/8yyYOn2hyujlndKiqiZ1l9VpNIYAWF2ZzvK6Z9/dVOd0UpUKWBlEPYiLd/Prm6Xi9hruf3UCLJ3hLsxeVdFRT0GG5QFlQkEVidIROWlCqH2kQ9cKoIfH8+Lpz2FRay8MvlzjdnG6t2VFGXmYCo4bEO92UkBET6ebyKUN5desxTrYG74cQpQYzDaJeWjh5GF+cPZqn/nmQlzYH37Uldc1trNt3gnl6EWvALS7MobG1nb9vP+50U5QKSRpEZ+G+hRMoHJnCt/+6hX0VDU435zRv7qzA4zUsKNBhuUA7b3Qaw5NjtCK3Uv1Eg+gsREW4+NVN04l0C3ct3xBUt/ItKiljSHwU00akOt2UkONyCYsKs3l7dyUV9YO/IK4KciIE+Q0LAk6D6CwNT4nlp9dPY8fxer73/DanmwNAW7uXtTvKmTshE7crvP4BD5Qlhdm0e42W/FH9wx0B0YnWLcqTR4Rd9XANoj6YOz6TL88dy5+KD7NivfMXOxYfqKau2aP3HupHeVmJTM5OYpVe3KoCQcS6W2xsKiQNh6Rs6xYWkbF2jyi8aBD10Tfm53P+mDS+87ctjpeAKSopI8rtYk5euqPtCHWLC3PYeqSO3WVa8kf1gctt3SspPgOScyAhC2KSwq730xUNoj6KcLt47IZCEqIj+ffl62lo8TjSDmMMRSVlXDhuCPHRejfR/nTNqZI/OmlB9VJEFMSmQNIwK3zihkBUnHUbc3WK/jb8kJkUw2M3TuNAZSP/uXIL
xgx8GZi9FQ0crGrSadsDICMxmjl56Tz/0RG8WvJHdcXlsoImfogVPInDrDvHuqOcbllQ0yDy04Vj07lnQT4vbjrKM+sODfj+X99eDmg1hYGyuDCbo7XNvL9fS/4omzvKGmJLzLImGsRnQFSCNRSnekXHcgLgrovH8eGBah56cTvTclKYkpM8YPteU1LG5OwkhiXHDtg+w9mlBUNJiI7gO6u2MjsvnbzMBMZlJjIuM4H0hCi9NXs4EBdExFiTDSJjwaVvo/7S32AAuFzCz66fxpWPvc1dz67npa/MITm2/09AVjW0sP5QNV+9JK/f96UssVFuHlw0iaf/eZBVG45Q73NuMCUu8lQwWd8TyMtKYGhSjAbUYOeOtEInIsb60r9nQGkQBUhafBS/vGk61y/9J9/6yyaW3jKj39981u6swBirMKcaOEum57Bkeg7GGMrqWthdXs/usgZ2lzewt7yBV7Ye449NbafWT4iOsELJJ5zyMhPJTonFpdd9BScRu9djh4/ObOtXGkQBNGNUKt++fAI/fLmE372zny/NGdOv+yvaXsbQpBgmDU/q1/2orokIQ5NjGJocw5y8jFPLjTFUNbayu6yBPeX17Cm3QuqNXRX8xee6s5hIF+MyExiXkUBeVuKpsBqZFkeE3uZ94LkjICL24/DRXs+A0SAKsC/OHs2HB07wyCs7KByZwoxRaf2yn+a2dt7aXcHiwmwd9gkyIkJ6QjTpCdFcMHbIac/VNLWyp7zhVDjtLm/gg/0n+NvGjys2RLldjE6PZ1yWFUx59jmo3PQ4oiP0BHjAiFgTDSLjrPDRXo9jNIgCTET4789O5epfvMOXl3/Ey1+dzZCE6IDv5/19VTS1tutN8AaZlLgoZuamMTP39A8oDS0e9p4Kp3r2lDWwpbSW1VuO0XFVgNsljBoSd1o4jctMYGxGArFRGlC94nJ93OuJjNXreYKEBlE/SI6N5Nc3T2fJb97jG3/exP99/tyAnwsoKikjLsrNBWOG9LyyCnoJ0RFMHZHC1BEppy1vbmtnb4Xdgyrr6EnVU1RSfur25SIwItUKqHFZpw/1JehFztZFpaeG3AL/oVD5T/+V9pPJ2cl87+oCHli1lV+t3cNX5gVuZpsxhjUl5czJSycmUj8Jh7KYSDeThiczafjplwS0erwcqGo8LZz2lDfw9u5KWtu9p9YbnhzDuKxEO5w+njCREhfCF1i6XD4TDWL1ep5BwK8gEpE04E9ALnAA+BdjTHU367qBYuCIMeYqf/Y7WNw0ayQf7D/Bz4p2MWNUKheOC0wtuG1H6zhW28w3FuQHZHtq8ImKcJGflUh+VuJpyz3tXg6daDp1DqojpJ7dX0Vz28cBlZ4QbQ3xZSVw5ZRhnBcKPeuoOOui0ogYp1uizpK/PaJvA2uMMY+IyLftx/d1s+7XgBIgbKZ4iQgPL57C1iO1fPW5jaz+6mwyk/z/T1JUUoYIXDJBqymo00W4XYzJSGBMRgKXTvp4uddrOFJz8lQwdUw3X7XhCHmZCaERRJGx0N7qdCtUH/gbRIuAi+2fnwLeoIsgEpEc4ErgR8A9fu5zUImPjuA3n5vBNb98h6/88SOWf+k8v6fmrikpZ/rIVNL7YRKECk0ulzAiLY4RaXHM9fkAY4zBo3XzlMP8nTKSZYw5BmB/7+4j+s+B/wC83Tx/iojcLiLFIlJcUVHhZ/OCQ35WIj+6dgrr9p/gp6/v8mtbx2ub2XKklnlaW04FgIgQqdcsKYf1+C9QRIpEZGsXX4t6swMRuQooN8as7836xphlxpiZxpiZGRkZPb9gkPjMjBxuOHcEv35jL2t3lPd5O2t2lAGwQKttK6VCRI9Dc8aY+d09JyJlIjLMGHNMRIYBXb3DXgRcIyJXADFAkog8Y4z5XJ9bPUh9/5pJbCqt5Rt/3sjLX51DdsrZFyot2l7GyLQ4xmUm9EMLlVJq4PnbJ38BuM3++Tbg+c4rGGP+0xiTY4zJBW4A/hGOIQTWVNxf
3zwdT7vhy8s30OrpcaTyNE2tHt7dW8X8iVlaTUEpFTL8DaJHgAUishtYYD9GRIaLyGp/GxeKRqfH8+hnzmHj4RoeeWXHWb327d2VtHq8zC/Q80NKqdDh16w5Y0wVMK+L5UeBK7pY/gbWzLqwduU5w/jwQC5Pvrufc3NTuXzKsF69rmh7GYkxEZyb2z/165RSygk6XcYh918xkakjUviPFZs5UNnY4/rtXsM/dpQzd3ymznJSSoUUfUdzSFSEi1/dVIjLJdy1fAPNbe1nXH/j4RqqGlt12rZSKuRoEDkoJzWOn/7LVLYfq+MHL24/47prSsqIcAkX52sQKaVCiwaRw+ZNzOLOT4/ljx8cYtVHpd2uV1RSxqzRaSTH6T1TlFKhRYMoCNx7aT6zctO4f+VWdpfVf+L5Q1VN7CprYJ5exKqUCkEaREEgwu3iFzcVEhfl5t+Xb6Cp1XPa80UlVjWF+Xp+SCkVgjSIgkRWUgz/e0MheysaeGDVVoz5uBBlUUkZeZkJjBoS72ALlVKqf2gQBZHZeel8fV4+qz46wnMfHgag9mQbH+w/obcEV0qFLL1Da5C5+5JxFB88wfde2MaU7GT2VTbi8Rrm6/khpVSI0h5RkHG7hJ9fP420uCi+/OwG/vbREYbERzFtRIrTTVNKqX6hQRSEhiRE84ubCimtPsk/dpRzyYRM3C4tcqpUWBABCa+35vA62kHk3Nw07ls4HoBLJw11uDVKqX4lAtGJkDQM0saAO7zOmoTX0Q4y/zZnDBeOTWfS8CSnm6KUCjQRiIqHqATrexjf2kWDKIiJCJOzk51uhlIqUE6Fjx1AYRw+vjSIlFKqP4lAZBxEJ0BkPLj0jEhnGkRKKRVoIhAZaw+7JWj49ECDSCmlAkHDp880iJRSyh+RsdawW1QCuNxOt2ZQ0iBSSqmzFRlrTTiITtTwCQANIqWU6o3ImI+H3cLsOp/+pr9NpZTqTkS0PeyWqOHTj/Q3q5RSvk6FTwK49Y7IA0GDSCmlIqKs4IlO1PBxgAaRUio8uSOt4IlKsIJIOUaDSCkVPjR8gpIGkVIqtLkjrMkG0QnW+R8VdDSIlFKhpyN8ouKtadcqqPlVg0JE0kTkdRHZbX9P7Wa9FBFZISI7RKRERC7wZ79KKfUJLjfEpkByDqTmQvwQDaFBwt9iSN8G1hhj8oA19uOu/C/wqjFmAjAVKPFzv0opdbq4NIhP1/AZhPwNokXAU/bPTwHXdl5BRJKATwG/AzDGtBpjavzcr1JKqRDhbxBlGWOOAdjfM7tYZwxQAfxeRD4SkSdEJN7P/SqllAoRPQaRiBSJyNYuvhb1ch8RwHTgN8aYQqCR7ofwEJHbRaRYRIorKip6uQullFKDVY+z5owx87t7TkTKRGSYMeaYiAwDyrtYrRQoNcassx+v4AxBZIxZBiwDmDlzpumpfUoppQY3f4fmXgBus3++DXi+8wrGmOPAYREZby+aB2z3c79KKaVChL9B9AiwQER2Awvsx4jIcBFZ7bPeV4DlIrIZmAY87Od+lVJKhQi/Lmg1xlRh9XA6Lz8KXOHzeCMw0599KaWUCk16U3WllFKO0iBSSinlKDEmeCemiUgFcLCPL08HKgPYHCeFyrGEynGAHkswCpXjAP+OZZQxJiOQjelvQR1E/hCRYmNMSJyXCpVjCZXjAD2WYBQqxwGhdSy9oUNzSimlHKVBpJRSylGhHETLnG5AAIXKsYTKcYAeSzAKleOA0DqWHoXsOSKllFKDQyj3iJRSSg0CIR1EIvJ9ETkiIhvtryt6flXwEpF7RcSISLrTbekrEXlIRDbbf4+/i8hwp9vUVyLyY/uuw5tFZJWIpDjdpr4QketEZJuIeEVkUM7UEpGFIrJTRPaISLdFlYOdiDwpIuUistXptgykkA4i28+MMdPsr9U9rx6cRGQEVj2/Q063xU8/NsacY4yZBrwEfNfh9vjjdWCyMeYcYBfwnw63p6+2AkuAt5xuSF+IiBv4FXA5
UADcKCIFzraqz/4PWOh0IwZaOARRqPgZ8B/AoD6pZ4yp83kYzyA+HmPM340xHvvh+0COk+3pK2NMiTFmp9Pt8MMsYI8xZp8xphV4Duvu0YOOMeYt4ITT7Rho4RBEd9tDJ0+KSKrTjekLEbkGOGKM2eR0WwJBRH4kIoeBmxncPSJfXwBecboRYSobOOzzuNRepgYJv6pvBwMRKQKGdvHUA8BvgIewPnU/BPwE6w0j6PRwHPcDlw5si/ruTMdijHneGPMA8ICI/CdwN/C9AW3gWejpWOx1HgA8wPKBbNvZ6M1xDGLSxbJB29MOR4M+iM50B1lfIvJbrHMSQam74xCRKcBoYJOIgDX8s0FEZtk3HQw6vf2bAM8CLxPEQdTTsYjIbcBVwDwTxNdCnMXfZDAqBUb4PM4BjjrUFtUHIT00Z9++vMNirJOyg4oxZosxJtMYk2uMycX6Tzc9WEOoJyKS5/PwGmCHU23xl4gsBO4DrjHGNDndnjD2IZAnIqNFJAq4Aevu0WqQCOkLWkXkD1h3hDXAAeAOY8wxJ9vkLxE5AMw0xgzKKsMi8ldgPODFqqx+pzHmiLOt6hsR2QNEA1X2oveNMXc62KQ+EZHFwC+ADKAG2GiMuczRRp0l+9KMnwNu4EljzI+cbVHfiMgfgYuxqm+XAd8zxvzO0UYNgJAOIqWUUsEvpIfmlFJKBT8NIqWUUo7SIFJKKeUoDSKllFKO0iBSSinlKA0iFRJEZIhPlfXjPlXXG0Tk1/2wvztF5NazfM0bg7W6tVL9adBXVlAKwBhThXXNGCLyfaDBGPM//bi/x/tr20qFG+0RqZAmIheLyEv2z98Xkafs+yAdEJElIvLfIrJFRF4VkUh7vRki8qaIrBeR1zpV6MBnW/faP78hIo+KyAcisktE5tjLY0XkObvo7p+AWJ/XXyoi/xSRDSLyFxFJEJFRIrJbRNJFxCUib4vIoKkxqFRfaRCpcDMWuBLrNgHPAGuNMVOAk8CVdhj9AvisMWYG8CTQm6v0I4wxs4Cv83HtvH8Hmuz7Ff0ImAFg39jwO8B8Y8x0oBi4xxhzEHgUeBz4JrDdGPN3/w9ZqeCmQ3Mq3LxijGkTkS1Y5WBetZdvAXKxyg9NBl63i8y6gd6UhVppf19vbwfgU8BjAMaYzSKy2V5+PtYN3N619xEF/NNe7wkRuQ64E3uoUalQp0Gkwk0LgDHGKyJtPhWzvVj/HwTYZoy5oC/bBdo5/f9VVzW0BHjdGHPjJ54QiePjG+wlAPVn2Q6lBh0dmlPqdDuBDBG5AEBEIkVkUh+39RbWzf8QkcnAOfby94GLRGSc/VyciOTbzz2KdV+j7wK/7eN+lRpUNIiU8mHfavqzwKMisgnYCFzYx839Bkiwh+T+A/jA3kcF8Hngj/Zz7wMTROTTwLnAo8aY5UCriPyrH4ej1KCg1beVUko5SntESimlHKVBpJRSylEaREoppRylQaSUUspRGkRKKaUcpUGklFLKURpESimlHKVBpJRSylH/H3saAbVPRnBmAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
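    "from pytorch_lightning import Trainer\n",
    "\n",
    "# from_dataset() derives dataset-dependent arguments (e.g. input and output lengths) from the TimeSeriesDataSet\n",
    "model = FullyConnectedForDistributionLossModel.from_dataset(dataset, hidden_size=10, n_hidden_layers=2, log_interval=1)\n",
    "# fast_dev_run=True runs a single training and validation batch as a quick end-to-end smoke test\n",
    "trainer = Trainer(fast_dev_run=True)\n",
    "# newer PyTorch Lightning versions name this argument train_dataloaders (older versions used train_dataloader)\n",
    "trainer.fit(model, train_dataloaders=dataloader, val_dataloaders=dataloader)"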
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.7 64-bit ('base': conda)",
   "language": "python",
   "name": "python37764bitbaseconda4052e86d6f894f0ea94517897490b6df"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
